kafka-commits mailing list archives

From rsiva...@apache.org
Subject [kafka] branch trunk updated: KAFKA-10554; Perform follower truncation based on diverging epochs in Fetch response (#9382)
Date Thu, 03 Dec 2020 10:13:29 GMT
This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 7ecc3a5  KAFKA-10554; Perform follower truncation based on diverging epochs in Fetch response (#9382)
7ecc3a5 is described below

commit 7ecc3a579a4b13e0cef4bd3129982ea3bc1a9341
Author: Rajini Sivaram <rajinisivaram@googlemail.com>
AuthorDate: Thu Dec 3 10:12:06 2020 +0000

    KAFKA-10554; Perform follower truncation based on diverging epochs in Fetch response (#9382)
    
    From IBP 2.7 onwards, fetch responses include the diverging epoch and offset when lastFetchedEpoch is provided in the fetch request. This PR uses that information for truncation and avoids the additional OffsetForLeaderEpoch requests in followers when lastFetchedEpoch is known.
    
    Co-authored-by: Jason Gustafson <jason@confluent.io>
    
    Reviewers: Jason Gustafson <jason@confluent.io>, Nikhil Bhatia <rite2nikhil@gmail.com>
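    
    In outline, the follower-side step this enables looks roughly as follows
    (a simplified sketch with illustrative types, not the actual Kafka classes):
    
        // Sketch only: illustrative types, not Kafka's real API.
        case class DivergingEpoch(epoch: Int, endOffset: Long)
        case class FetchResult(records: Seq[Array[Byte]],
                               divergingEpoch: Option[DivergingEpoch])
    
        def followerFetchStep(lastFetchedEpoch: Option[Int],
                              fetch: Option[Int] => FetchResult,
                              truncateTo: Long => Unit,
                              append: Seq[Array[Byte]] => Unit): Unit = {
          // lastFetchedEpoch travels with the fetch; the leader compares it
          // against its own epoch history.
          val result = fetch(lastFetchedEpoch)
          result.divergingEpoch match {
            case Some(d) =>
              // The response already says where the logs diverge, so truncate
              // locally and refetch; no OffsetForLeaderEpoch round trip needed.
              truncateTo(d.endOffset)
            case None =>
              append(result.records)
          }
        }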
---
 core/src/main/scala/kafka/api/ApiVersion.scala     |   2 +
 core/src/main/scala/kafka/log/Log.scala            |  16 +-
 .../kafka/server/AbstractFetcherManager.scala      |  32 ++--
 .../scala/kafka/server/AbstractFetcherThread.scala | 117 +++++++++----
 .../kafka/server/ReplicaAlterLogDirsManager.scala  |   2 +-
 .../kafka/server/ReplicaAlterLogDirsThread.scala   |  13 +-
 .../scala/kafka/server/ReplicaFetcherThread.scala  |  17 +-
 .../main/scala/kafka/server/ReplicaManager.scala   |  17 +-
 .../server/DynamicBrokerReconfigurationTest.scala  |   4 +-
 .../kafka/server/AbstractFetcherManagerTest.scala  |   8 +-
 .../kafka/server/AbstractFetcherThreadTest.scala   | 166 ++++++++++++++----
 .../AbstractFetcherThreadWithIbp26Test.scala       |  23 +++
 .../server/ReplicaAlterLogDirsThreadTest.scala     |  49 +++---
 .../kafka/server/ReplicaFetcherThreadTest.scala    | 192 +++++++++++++++++----
 .../unit/kafka/server/ReplicaManagerTest.scala     |  28 ++-
 ...eplicationProtocolAcceptanceWithIbp26Test.scala |  29 ++++
 .../util/ReplicaFetcherMockBlockingSend.scala      |  18 +-
 .../jmh/fetcher/ReplicaFetcherThreadBenchmark.java |   7 +-
 18 files changed, 565 insertions(+), 175 deletions(-)

diff --git a/core/src/main/scala/kafka/api/ApiVersion.scala b/core/src/main/scala/kafka/api/ApiVersion.scala
index 16e6fb9..fdebc27 100644
--- a/core/src/main/scala/kafka/api/ApiVersion.scala
+++ b/core/src/main/scala/kafka/api/ApiVersion.scala
@@ -130,6 +130,8 @@ object ApiVersion {
 
   val latestVersion: ApiVersion = allVersions.last
 
+  def isTruncationOnFetchSupported(version: ApiVersion): Boolean = version >= KAFKA_2_7_IV1
+
   /**
    * Return the minimum `ApiVersion` that supports `RecordVersion`.
    */
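
Call sites gate the new behaviour on the inter-broker protocol version; for
example, ReplicaFetcherThread (further down in this patch) derives its flag as:

    // From ReplicaFetcherThread.scala below; `brokerConfig` is a KafkaConfig.
    override protected val isTruncationOnFetchSupported =
      ApiVersion.isTruncationOnFetchSupported(brokerConfig.interBrokerProtocolVersion)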
diff --git a/core/src/main/scala/kafka/log/Log.scala b/core/src/main/scala/kafka/log/Log.scala
index 15dc9ce..b325adb 100644
--- a/core/src/main/scala/kafka/log/Log.scala
+++ b/core/src/main/scala/kafka/log/Log.scala
@@ -50,11 +50,11 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 import scala.collection.{Seq, Set, mutable}
 
 object LogAppendInfo {
-  val UnknownLogAppendInfo = LogAppendInfo(None, -1, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, -1L,
+  val UnknownLogAppendInfo = LogAppendInfo(None, -1, None, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, -1L,
     RecordConversionStats.EMPTY, NoCompressionCodec, NoCompressionCodec, -1, -1, offsetsMonotonic = false, -1L)
 
   def unknownLogAppendInfoWithLogStartOffset(logStartOffset: Long): LogAppendInfo =
-    LogAppendInfo(None, -1, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, logStartOffset,
+    LogAppendInfo(None, -1, None, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, logStartOffset,
       RecordConversionStats.EMPTY, NoCompressionCodec, NoCompressionCodec, -1, -1,
       offsetsMonotonic = false, -1L)
 
@@ -64,7 +64,7 @@ object LogAppendInfo {
   * in unknownLogAppendInfoWithLogStartOffset, but with additional fields recordErrors and errorMessage
    */
   def unknownLogAppendInfoWithAdditionalInfo(logStartOffset: Long, recordErrors: Seq[RecordError], errorMessage: String): LogAppendInfo =
-    LogAppendInfo(None, -1, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, logStartOffset,
+    LogAppendInfo(None, -1, None, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, logStartOffset,
       RecordConversionStats.EMPTY, NoCompressionCodec, NoCompressionCodec, -1, -1,
       offsetsMonotonic = false, -1L, recordErrors, errorMessage)
 }
@@ -82,6 +82,7 @@ object LeaderHwChange {
  * @param firstOffset The first offset in the message set unless the message format is less than V2 and we are appending
  *                    to the follower.
  * @param lastOffset The last offset in the message set
+ * @param lastLeaderEpoch The partition leader epoch corresponding to the last offset, if available.
  * @param maxTimestamp The maximum timestamp of the message set.
  * @param offsetOfMaxTimestamp The offset of the message with the maximum timestamp.
  * @param logAppendTime The log append time (if used) of the message set, otherwise Message.NoTimestamp
@@ -99,6 +100,7 @@ object LeaderHwChange {
  */
 case class LogAppendInfo(var firstOffset: Option[Long],
                          var lastOffset: Long,
+                         var lastLeaderEpoch: Option[Int],
                          var maxTimestamp: Long,
                          var offsetOfMaxTimestamp: Long,
                          var logAppendTime: Long,
@@ -1383,6 +1385,7 @@ class Log(@volatile private var _dir: File,
     var validBytesCount = 0
     var firstOffset: Option[Long] = None
     var lastOffset = -1L
+    var lastLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH
     var sourceCodec: CompressionCodec = NoCompressionCodec
     var monotonic = true
     var maxTimestamp = RecordBatch.NO_TIMESTAMP
@@ -1415,6 +1418,7 @@ class Log(@volatile private var _dir: File,
 
       // update the last offset seen
       lastOffset = batch.lastOffset
+      lastLeaderEpoch = batch.partitionLeaderEpoch
 
       // Check if the message sizes are valid.
       val batchSize = batch.sizeInBytes
@@ -1446,7 +1450,11 @@ class Log(@volatile private var _dir: File,
 
     // Apply broker-side compression if any
     val targetCodec = BrokerCompressionCodec.getTargetCompressionCodec(config.compressionType, sourceCodec)
-    LogAppendInfo(firstOffset, lastOffset, maxTimestamp, offsetOfMaxTimestamp, RecordBatch.NO_TIMESTAMP, logStartOffset,
+    val lastLeaderEpochOpt: Option[Int] = if (lastLeaderEpoch != RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      Some(lastLeaderEpoch)
+    else
+      None
+    LogAppendInfo(firstOffset, lastOffset, lastLeaderEpochOpt, maxTimestamp, offsetOfMaxTimestamp, RecordBatch.NO_TIMESTAMP, logStartOffset,
       RecordConversionStats.EMPTY, sourceCodec, targetCodec, shallowMessageCount, validBytesCount, monotonic, lastOffsetOfFirstBatch)
   }
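
RecordBatch.NO_PARTITION_LEADER_EPOCH (-1) is the wire-level sentinel for "no
epoch" (e.g. message formats older than V2); the conversion above turns it into
an Option so callers can pass it straight into the fetcher's lastFetchedEpoch.
The same conversion in isolation:

    import org.apache.kafka.common.record.RecordBatch

    // -1 (RecordBatch.NO_PARTITION_LEADER_EPOCH) means the batch carries no
    // leader epoch, as with pre-V2 message formats.
    def lastLeaderEpochOpt(lastLeaderEpoch: Int): Option[Int] =
      if (lastLeaderEpoch != RecordBatch.NO_PARTITION_LEADER_EPOCH)
        Some(lastLeaderEpoch)
      else
        None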
 
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
index 5a350e9..3d9189e 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
@@ -17,15 +17,14 @@
 
 package kafka.server
 
-import kafka.utils.Implicits._
-import kafka.utils.Logging
 import kafka.cluster.BrokerEndPoint
 import kafka.metrics.KafkaMetricsGroup
+import kafka.utils.Implicits._
+import kafka.utils.Logging
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.utils.Utils
 
-import scala.collection.mutable
-import scala.collection.{Map, Set}
+import scala.collection.{Map, Set, mutable}
 
 abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: String, clientId: String, numFetchers: Int)
   extends Logging with KafkaMetricsGroup {
@@ -64,11 +63,16 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
   def resizeThreadPool(newSize: Int): Unit = {
     def migratePartitions(newSize: Int): Unit = {
       fetcherThreadMap.forKeyValue { (id, thread) =>
-        val removedPartitions = thread.partitionsAndOffsets
-        removeFetcherForPartitions(removedPartitions.keySet)
+        val partitionStates = removeFetcherForPartitions(thread.partitions)
         if (id.fetcherId >= newSize)
           thread.shutdown()
-        addFetcherForPartitions(removedPartitions)
+        val fetchStates = partitionStates.map { case (topicPartition, currentFetchState) =>
+          val initialFetchState = InitialFetchState(thread.sourceBroker,
+            currentLeaderEpoch = currentFetchState.currentLeaderEpoch,
+            initOffset = currentFetchState.fetchOffset)
+          topicPartition -> initialFetchState
+        }
+        addFetcherForPartitions(fetchStates)
       }
     }
     lock synchronized {
@@ -142,29 +146,27 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
             addAndStartFetcherThread(brokerAndFetcherId, brokerIdAndFetcherId)
         }
 
-        val initialOffsetAndEpochs = initialFetchOffsets.map { case (tp, brokerAndInitOffset) =>
-          tp -> OffsetAndEpoch(brokerAndInitOffset.initOffset, brokerAndInitOffset.currentLeaderEpoch)
-        }
-
-        addPartitionsToFetcherThread(fetcherThread, initialOffsetAndEpochs)
+        addPartitionsToFetcherThread(fetcherThread, initialFetchOffsets)
       }
     }
   }
 
   protected def addPartitionsToFetcherThread(fetcherThread: T,
-                                             initialOffsetAndEpochs: collection.Map[TopicPartition, OffsetAndEpoch]): Unit = {
+                                             initialOffsetAndEpochs: collection.Map[TopicPartition, InitialFetchState]): Unit = {
     fetcherThread.addPartitions(initialOffsetAndEpochs)
     info(s"Added fetcher to broker ${fetcherThread.sourceBroker.id} for partitions $initialOffsetAndEpochs")
   }
 
-  def removeFetcherForPartitions(partitions: Set[TopicPartition]): Unit = {
+  def removeFetcherForPartitions(partitions: Set[TopicPartition]): Map[TopicPartition, PartitionFetchState] = {
+    val fetchStates = mutable.Map.empty[TopicPartition, PartitionFetchState]
     lock synchronized {
       for (fetcher <- fetcherThreadMap.values)
-        fetcher.removePartitions(partitions)
+        fetchStates ++= fetcher.removePartitions(partitions)
       failedPartitions.removeAll(partitions)
     }
     if (partitions.nonEmpty)
       info(s"Removed fetcher for partitions $partitions")
+    fetchStates
   }
 
   def shutdownIdleFetcherThreads(): Unit = {
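
resizeThreadPool now migrates partitions by capturing the fetch states returned
from removeFetcherForPartitions and re-adding them, instead of reading
thread.partitionsAndOffsets before removal. Schematically (simplified types,
not the real signatures):

    // Simplified shape of the migrate-on-resize pattern; the real code uses
    // TopicPartition, PartitionFetchState and InitialFetchState.
    case class FetchState(fetchOffset: Long, currentLeaderEpoch: Int)
    case class InitialState(initOffset: Long, currentLeaderEpoch: Int)

    def migrate(partitions: Set[String],
                remove: Set[String] => Map[String, FetchState],
                add: Map[String, InitialState] => Unit): Unit = {
      val states = remove(partitions) // capture state together with removal
      add(states.map { case (tp, s) =>
        tp -> InitialState(s.fetchOffset, s.currentLeaderEpoch)
      })
    }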
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
index edaa6b7..2ef618a 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
@@ -99,7 +99,9 @@ abstract class AbstractFetcherThread(name: String,
 
   protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long
 
-  protected def isOffsetForLeaderEpochSupported: Boolean
+  protected val isOffsetForLeaderEpochSupported: Boolean
+
+  protected val isTruncationOnFetchSupported: Boolean
 
   override def shutdown(): Unit = {
     initiateShutdown()
@@ -225,6 +227,14 @@ abstract class AbstractFetcherThread(name: String,
     }
   }
 
+  protected def truncateOnFetchResponse(epochEndOffsets: Map[TopicPartition, EpochEndOffset]): Unit = {
+    inLock(partitionMapLock) {
+      val ResultWithPartitions(fetchOffsets, partitionsWithError) = maybeTruncateToEpochEndOffsets(epochEndOffsets, Map.empty)
+      handlePartitionsWithErrors(partitionsWithError, "truncateOnFetchResponse")
+      updateFetchOffsetAndMaybeMarkTruncationComplete(fetchOffsets)
+    }
+  }
+
   // Visible for testing
   private[server] def truncateToHighWatermark(partitions: Set[TopicPartition]): Unit = inLock(partitionMapLock) {
     val fetchOffsets = mutable.HashMap.empty[TopicPartition, OffsetTruncationState]
@@ -253,7 +263,7 @@ abstract class AbstractFetcherThread(name: String,
       leaderEpochOffset.error match {
         case Errors.NONE =>
           val offsetTruncationState = getOffsetTruncationState(tp, leaderEpochOffset)
-          if(doTruncate(tp, offsetTruncationState))
+          if (doTruncate(tp, offsetTruncationState))
             fetchOffsets.put(tp, offsetTruncationState)
 
         case Errors.FENCED_LEADER_EPOCH =>
@@ -294,6 +304,7 @@ abstract class AbstractFetcherThread(name: String,
   private def processFetchRequest(sessionPartitions: util.Map[TopicPartition, FetchRequest.PartitionData],
                                   fetchRequest: FetchRequest.Builder): Unit = {
     val partitionsWithError = mutable.Set[TopicPartition]()
+    val divergingEndOffsets = mutable.Map.empty[TopicPartition, EpochEndOffset]
     var responseData: Map[TopicPartition, FetchData] = Map.empty
 
     try {
@@ -341,11 +352,18 @@ abstract class AbstractFetcherThread(name: String,
                       // ReplicaDirAlterThread may have removed topicPartition from the partitionStates after processing the partition data
                       if (validBytes > 0 && partitionStates.contains(topicPartition)) {
                         // Update partitionStates only if there is no exception during processPartitionData
-                        val newFetchState = PartitionFetchState(nextOffset, Some(lag), currentFetchState.currentLeaderEpoch, state = Fetching)
+                        val newFetchState = PartitionFetchState(nextOffset, Some(lag),
+                          currentFetchState.currentLeaderEpoch, state = Fetching,
+                          logAppendInfo.lastLeaderEpoch)
                         partitionStates.updateAndMoveToEnd(topicPartition, newFetchState)
                         fetcherStats.byteRate.mark(validBytes)
                       }
                     }
+                    if (isTruncationOnFetchSupported) {
+                      partitionData.divergingEpoch.ifPresent { divergingEpoch =>
+                        divergingEndOffsets += topicPartition -> new EpochEndOffset(Errors.NONE, divergingEpoch.epoch, divergingEpoch.endOffset)
+                      }
+                    }
                   } catch {
                     case ime@( _: CorruptRecordException | _: InvalidRecordException) =>
                       // we log the error and continue. This ensures two things
@@ -400,17 +418,24 @@ abstract class AbstractFetcherThread(name: String,
       }
     }
 
+    if (divergingEndOffsets.nonEmpty)
+      truncateOnFetchResponse(divergingEndOffsets)
     if (partitionsWithError.nonEmpty) {
       handlePartitionsWithErrors(partitionsWithError, "processFetchRequest")
     }
   }
 
+  /**
+   * This is used to mark partitions for truncation in ReplicaAlterLogDirsThread after leader
+   * offsets are known.
+   */
   def markPartitionsForTruncation(topicPartition: TopicPartition, truncationOffset: Long): Unit = {
     partitionMapLock.lockInterruptibly()
     try {
       Option(partitionStates.stateValue(topicPartition)).foreach { state =>
         val newState = PartitionFetchState(math.min(truncationOffset, state.fetchOffset),
-          state.lag, state.currentLeaderEpoch, state.delay, state = Truncating)
+          state.lag, state.currentLeaderEpoch, state.delay, state = Truncating,
+          lastFetchedEpoch = None)
         partitionStates.updateAndMoveToEnd(topicPartition, newState)
         partitionMapCond.signalAll()
       }
@@ -426,21 +451,37 @@ abstract class AbstractFetcherThread(name: String,
     warn(s"Partition $topicPartition marked as failed")
   }
 
-  def addPartitions(initialFetchStates: Map[TopicPartition, OffsetAndEpoch]): Set[TopicPartition] = {
+  /**
+   * Returns initial partition fetch state based on current state and the provided `initialFetchState`.
+   * From IBP 2.7 onwards, we can rely on truncation based on diverging data returned in fetch responses.
+   * For older versions, we can skip the truncation step iff the leader epoch matches the existing epoch.
+   */
+  private def partitionFetchState(tp: TopicPartition, initialFetchState: InitialFetchState, currentState: PartitionFetchState): PartitionFetchState = {
+    if (currentState != null && currentState.currentLeaderEpoch == initialFetchState.currentLeaderEpoch) {
+      currentState
+    } else if (initialFetchState.initOffset < 0) {
+      fetchOffsetAndTruncate(tp, initialFetchState.currentLeaderEpoch)
+    } else if (isTruncationOnFetchSupported) {
+      // With old message format, `latestEpoch` will be empty and we use Truncating state
+      // to truncate to high watermark.
+      val lastFetchedEpoch = latestEpoch(tp)
+      val state = if (lastFetchedEpoch.nonEmpty) Fetching else Truncating
+      PartitionFetchState(initialFetchState.initOffset, None, initialFetchState.currentLeaderEpoch,
+          state, lastFetchedEpoch)
+    } else {
+      PartitionFetchState(initialFetchState.initOffset, None, initialFetchState.currentLeaderEpoch,
+        state = Truncating, lastFetchedEpoch = None)
+    }
+  }
+
+  def addPartitions(initialFetchStates: Map[TopicPartition, InitialFetchState]): Set[TopicPartition] = {
     partitionMapLock.lockInterruptibly()
     try {
       failedPartitions.removeAll(initialFetchStates.keySet)
 
       initialFetchStates.forKeyValue { (tp, initialFetchState) =>
-        // We can skip the truncation step iff the leader epoch matches the existing epoch
         val currentState = partitionStates.stateValue(tp)
-        val updatedState = if (currentState != null && currentState.currentLeaderEpoch == initialFetchState.leaderEpoch) {
-          currentState
-        } else if (initialFetchState.offset < 0) {
-          fetchOffsetAndTruncate(tp, initialFetchState.leaderEpoch)
-        } else {
-          PartitionFetchState(initialFetchState.offset, None, initialFetchState.leaderEpoch, state = Truncating)
-        }
+        val updatedState = partitionFetchState(tp, initialFetchState, currentState)
         partitionStates.updateAndMoveToEnd(tp, updatedState)
       }
 
@@ -460,9 +501,13 @@ abstract class AbstractFetcherThread(name: String,
       .map { case (topicPartition, currentFetchState) =>
         val maybeTruncationComplete = fetchOffsets.get(topicPartition) match {
           case Some(offsetTruncationState) =>
-            val state = if (offsetTruncationState.truncationCompleted) Fetching else Truncating
+            val lastFetchedEpoch = latestEpoch(topicPartition)
+            val state = if (isTruncationOnFetchSupported || offsetTruncationState.truncationCompleted)
+              Fetching
+            else
+              Truncating
             PartitionFetchState(offsetTruncationState.offset, currentFetchState.lag,
-              currentFetchState.currentLeaderEpoch, currentFetchState.delay, state)
+              currentFetchState.currentLeaderEpoch, currentFetchState.delay, state, lastFetchedEpoch)
           case None => currentFetchState
         }
         (topicPartition, maybeTruncationComplete)
@@ -596,7 +641,8 @@ abstract class AbstractFetcherThread(name: String,
       truncate(topicPartition, OffsetTruncationState(leaderEndOffset, truncationCompleted = true))
 
       fetcherLagStats.getAndMaybePut(topicPartition).lag = 0
-      PartitionFetchState(leaderEndOffset, Some(0), currentLeaderEpoch, state = Fetching)
+      PartitionFetchState(leaderEndOffset, Some(0), currentLeaderEpoch,
+        state = Fetching, lastFetchedEpoch = latestEpoch(topicPartition))
     } else {
       /**
        * If the leader's log end offset is greater than the follower's log end offset, there are two possibilities:
@@ -629,7 +675,8 @@ abstract class AbstractFetcherThread(name: String,
 
       val initialLag = leaderEndOffset - offsetToFetch
       fetcherLagStats.getAndMaybePut(topicPartition).lag = initialLag
-      PartitionFetchState(offsetToFetch, Some(initialLag), currentLeaderEpoch, state = Fetching)
+      PartitionFetchState(offsetToFetch, Some(initialLag), currentLeaderEpoch,
+        state = Fetching, lastFetchedEpoch = latestEpoch(topicPartition))
     }
   }
 
@@ -640,7 +687,8 @@ abstract class AbstractFetcherThread(name: String,
         Option(partitionStates.stateValue(partition)).foreach { currentFetchState =>
           if (!currentFetchState.isDelayed) {
             partitionStates.updateAndMoveToEnd(partition, PartitionFetchState(currentFetchState.fetchOffset,
-              currentFetchState.lag, currentFetchState.currentLeaderEpoch, Some(new DelayedItem(delay)), currentFetchState.state))
+              currentFetchState.lag, currentFetchState.currentLeaderEpoch, Some(new DelayedItem(delay)),
+              currentFetchState.state, currentFetchState.lastFetchedEpoch))
           }
         }
       }
@@ -648,13 +696,15 @@ abstract class AbstractFetcherThread(name: String,
     } finally partitionMapLock.unlock()
   }
 
-  def removePartitions(topicPartitions: Set[TopicPartition]): Unit = {
+  def removePartitions(topicPartitions: Set[TopicPartition]): Map[TopicPartition, PartitionFetchState] = {
     partitionMapLock.lockInterruptibly()
     try {
-      topicPartitions.foreach { topicPartition =>
+      topicPartitions.map { topicPartition =>
+        val state = partitionStates.stateValue(topicPartition)
         partitionStates.remove(topicPartition)
         fetcherLagStats.unregister(topicPartition)
-      }
+        topicPartition -> state
+      }.filter(_._2 != null).toMap
     } finally partitionMapLock.unlock()
   }
 
@@ -664,20 +714,17 @@ abstract class AbstractFetcherThread(name: String,
     finally partitionMapLock.unlock()
   }
 
+  def partitions: Set[TopicPartition] = {
+    partitionMapLock.lockInterruptibly()
+    try partitionStates.partitionSet.asScala.toSet
+    finally partitionMapLock.unlock()
+  }
+
   // Visible for testing
   private[server] def fetchState(topicPartition: TopicPartition): Option[PartitionFetchState] = inLock(partitionMapLock) {
     Option(partitionStates.stateValue(topicPartition))
   }
 
-  private[server] def partitionsAndOffsets: Map[TopicPartition, InitialFetchState] = inLock(partitionMapLock) {
-    partitionStates.partitionStateMap.asScala.map { case (topicPartition, currentFetchState) =>
-      val initialFetchState = InitialFetchState(sourceBroker,
-        currentLeaderEpoch = currentFetchState.currentLeaderEpoch,
-        initOffset = currentFetchState.fetchOffset)
-      topicPartition -> initialFetchState
-    }
-  }
-
   protected def toMemoryRecords(records: Records): MemoryRecords = {
     (records: @unchecked) match {
       case r: MemoryRecords => r
@@ -687,7 +734,6 @@ abstract class AbstractFetcherThread(name: String,
         MemoryRecords.readableRecords(buffer)
     }
   }
-
 }
 
 object AbstractFetcherThread {
@@ -769,8 +815,9 @@ case object Truncating extends ReplicaState
 case object Fetching extends ReplicaState
 
 object PartitionFetchState {
-  def apply(offset: Long, lag: Option[Long], currentLeaderEpoch: Int, state: ReplicaState): PartitionFetchState = {
-    PartitionFetchState(offset, lag, currentLeaderEpoch, None, state)
+  def apply(offset: Long, lag: Option[Long], currentLeaderEpoch: Int, state: ReplicaState,
+            lastFetchedEpoch: Option[Int]): PartitionFetchState = {
+    PartitionFetchState(offset, lag, currentLeaderEpoch, None, state, lastFetchedEpoch)
   }
 }
 
@@ -786,7 +833,8 @@ case class PartitionFetchState(fetchOffset: Long,
                                lag: Option[Long],
                                currentLeaderEpoch: Int,
                                delay: Option[DelayedItem],
-                               state: ReplicaState) {
+                               state: ReplicaState,
+                               lastFetchedEpoch: Option[Int]) {
 
   def isReadyForFetch: Boolean = state == Fetching && !isDelayed
 
@@ -799,6 +847,7 @@ case class PartitionFetchState(fetchOffset: Long,
   override def toString: String = {
     s"FetchState(fetchOffset=$fetchOffset" +
       s", currentLeaderEpoch=$currentLeaderEpoch" +
+      s", lastFetchedEpoch=$lastFetchedEpoch" +
       s", state=$state" +
       s", lag=$lag" +
       s", delay=${delay.map(_.delayMs).getOrElse(0)}ms" +
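
PartitionFetchState now threads lastFetchedEpoch through every state
transition, and the companion apply gains the extra parameter. A hypothetical
construction using the five-argument apply added above:

    // Hypothetical values; the shape matches the companion apply in this patch.
    val fetchState = PartitionFetchState(
      100L,                       // fetchOffset
      Some(0L),                   // lag
      5,                          // currentLeaderEpoch
      state = Fetching,
      lastFetchedEpoch = Some(5)) // epoch of the last locally appended batch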
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
index 85eaeaa..b45a766 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
@@ -36,7 +36,7 @@ class ReplicaAlterLogDirsManager(brokerConfig: KafkaConfig,
   }
 
   override protected def addPartitionsToFetcherThread(fetcherThread: ReplicaAlterLogDirsThread,
-                                                      initialOffsetAndEpochs: collection.Map[TopicPartition, OffsetAndEpoch]): Unit = {
+                                                      initialOffsetAndEpochs: collection.Map[TopicPartition, InitialFetchState]): Unit = {
     val addedPartitions = fetcherThread.addPartitions(initialOffsetAndEpochs)
     val (addedInitialOffsets, notAddedInitialOffsets) = initialOffsetAndEpochs.partition { case (tp, _) =>
       addedPartitions.contains(tp)
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
index 91b873f..0812f87 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
@@ -36,6 +36,7 @@ import org.apache.kafka.common.requests.{EpochEndOffset, FetchRequest, FetchResp
 
 import scala.jdk.CollectionConverters._
 import scala.collection.{Map, Seq, Set, mutable}
+import scala.compat.java8.OptionConverters._
 
 class ReplicaAlterLogDirsThread(name: String,
                                 sourceBroker: BrokerEndPoint,
@@ -131,7 +132,7 @@ class ReplicaAlterLogDirsThread(name: String,
     logAppendInfo
   }
 
-  override def addPartitions(initialFetchStates: Map[TopicPartition, OffsetAndEpoch]): Set[TopicPartition] = {
+  override def addPartitions(initialFetchStates: Map[TopicPartition, InitialFetchState]): Set[TopicPartition] = {
     partitionMapLock.lockInterruptibly()
     try {
       // It is possible that the log dir fetcher completed just before this call, so we
@@ -181,7 +182,9 @@ class ReplicaAlterLogDirsThread(name: String,
     }
   }
 
-  override protected def isOffsetForLeaderEpochSupported: Boolean = true
+  override protected val isOffsetForLeaderEpochSupported: Boolean = true
+
+  override protected val isTruncationOnFetchSupported: Boolean = false
 
   /**
    * Truncate the log for each partition based on current replica's returned epoch and offset.
@@ -248,8 +251,12 @@ class ReplicaAlterLogDirsThread(name: String,
 
     try {
       val logStartOffset = replicaMgr.futureLocalLogOrException(tp).logStartOffset
+      val lastFetchedEpoch = if (isTruncationOnFetchSupported)
+        fetchState.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
+      else
+        Optional.empty[Integer]
       requestMap.put(tp, new FetchRequest.PartitionData(fetchState.fetchOffset, logStartOffset,
-        fetchSize, Optional.of(fetchState.currentLeaderEpoch)))
+        fetchSize, Optional.of(fetchState.currentLeaderEpoch), lastFetchedEpoch))
     } catch {
       case e: KafkaStorageException =>
         debug(s"Failed to build fetch for $tp", e)
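
The scala.compat.java8.OptionConverters import added above (and in
ReplicaFetcherThread below) supplies the asJava bridge used when building the
fetch request; the asInstanceOf step boxes the primitive Int because
FetchRequest.PartitionData takes Optional[Integer] while the fetch state holds
Option[Int]. In isolation:

    import java.util.Optional
    import scala.compat.java8.OptionConverters._

    // Scala Option[Int] -> java.util.Optional[Integer]; map boxes each Int
    // into a java.lang.Integer before the conversion.
    def toJavaEpoch(epoch: Option[Int]): Optional[Integer] =
      epoch.map(_.asInstanceOf[Integer]).asJava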
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index 92765c9..2db3d6a 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -39,6 +39,7 @@ import org.apache.kafka.common.utils.{LogContext, Time}
 
 import scala.jdk.CollectionConverters._
 import scala.collection.{Map, mutable}
+import scala.compat.java8.OptionConverters._
 
 class ReplicaFetcherThread(name: String,
                            fetcherId: Int,
@@ -103,7 +104,8 @@ class ReplicaFetcherThread(name: String,
   private val minBytes = brokerConfig.replicaFetchMinBytes
   private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes
   private val fetchSize = brokerConfig.replicaFetchMaxBytes
-  private val brokerSupportsLeaderEpochRequest = brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV2
+  override protected val isOffsetForLeaderEpochSupported: Boolean = brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV2
+  override protected val isTruncationOnFetchSupported = ApiVersion.isTruncationOnFetchSupported(brokerConfig.interBrokerProtocolVersion)
   val fetchSessionHandler = new FetchSessionHandler(logContext, sourceBroker.id)
 
   override protected def latestEpoch(topicPartition: TopicPartition): Option[Int] = {
@@ -267,8 +269,16 @@ class ReplicaFetcherThread(name: String,
       if (fetchState.isReadyForFetch && !shouldFollowerThrottle(quota, fetchState, topicPartition)) {
         try {
           val logStartOffset = this.logStartOffset(topicPartition)
+          val lastFetchedEpoch = if (isTruncationOnFetchSupported)
+            fetchState.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
+          else
+            Optional.empty[Integer]
           builder.add(topicPartition, new FetchRequest.PartitionData(
-            fetchState.fetchOffset, logStartOffset, fetchSize, Optional.of(fetchState.currentLeaderEpoch)))
+            fetchState.fetchOffset,
+            logStartOffset,
+            fetchSize,
+            Optional.of(fetchState.currentLeaderEpoch),
+            lastFetchedEpoch))
         } catch {
           case _: KafkaStorageException =>
             // The replica has already been marked offline due to log directory failure and the original failure should have already been logged.
@@ -345,9 +355,6 @@ class ReplicaFetcherThread(name: String,
     }
   }
 
-  override def isOffsetForLeaderEpochSupported: Boolean = brokerSupportsLeaderEpochRequest
-
-
   /**
    *  To avoid ISR thrashing, we only throttle a replica on the follower if it's in the throttled replica list,
    *  the quota is exceeded and the replica is not in sync.
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index b9487fe..23da8bb 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -1667,9 +1667,10 @@ class ReplicaManager(val config: KafkaConfig,
         val partitionsToMakeFollowerWithLeaderAndOffset = partitionsToMakeFollower.map { partition =>
           val leader = metadataCache.getAliveBrokers.find(_.id == partition.leaderReplicaIdOpt.get).get
             .brokerEndPoint(config.interBrokerListenerName)
-          val fetchOffset = partition.localLogOrException.highWatermark
+          val log = partition.localLogOrException
+          val fetchOffset = initialFetchOffset(log)
           partition.topicPartition -> InitialFetchState(leader, partition.getLeaderEpoch, fetchOffset)
-       }.toMap
+        }.toMap
 
         replicaFetcherManager.addFetcherForPartitions(partitionsToMakeFollowerWithLeaderAndOffset)
       }
@@ -1691,6 +1692,18 @@ class ReplicaManager(val config: KafkaConfig,
     partitionsToMakeFollower
   }
 
+  /**
+   * From IBP 2.7 onwards, we send latest fetch epoch in the request and truncate if a
+   * diverging epoch is returned in the response, avoiding the need for a separate
+   * OffsetForLeaderEpoch request.
+   */
+  private def initialFetchOffset(log: Log): Long = {
+    if (ApiVersion.isTruncationOnFetchSupported(config.interBrokerProtocolVersion) && log.latestEpoch.nonEmpty)
+      log.logEndOffset
+    else
+      log.highWatermark
+  }
+
   private def maybeShrinkIsr(): Unit = {
     trace("Evaluating ISR list of partitions to see which replicas can be removed from the ISR")
 
diff --git a/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala b/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala
index dc55c13..5c38fe2 100644
--- a/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala
+++ b/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala
@@ -849,7 +849,9 @@ class DynamicBrokerReconfigurationTest extends ZooKeeperTestHarness with SaslSet
         val fetcherThreads = replicaFetcherManager.fetcherThreadMap.filter(_._2.fetchState(tp).isDefined)
         assertEquals(1, fetcherThreads.size)
         assertEquals(replicaFetcherManager.getFetcherId(tp), fetcherThreads.head._1.fetcherId)
-        assertEquals(Some(truncationOffset), fetcherThreads.head._2.fetchState(tp).map(_.fetchOffset))
+        val thread = fetcherThreads.head._2
+        assertEquals(Some(truncationOffset), thread.fetchState(tp).map(_.fetchOffset))
+        assertEquals(Some(Truncating), thread.fetchState(tp).map(_.state))
       }
     }
   }
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
index 6f07e44..847449b 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
@@ -57,11 +57,11 @@ class AbstractFetcherManagerTest {
       initOffset = fetchOffset)
 
     EasyMock.expect(fetcher.start())
-    EasyMock.expect(fetcher.addPartitions(Map(tp -> OffsetAndEpoch(fetchOffset, leaderEpoch))))
+    EasyMock.expect(fetcher.addPartitions(Map(tp -> initialFetchState)))
         .andReturn(Set(tp))
     EasyMock.expect(fetcher.fetchState(tp))
-      .andReturn(Some(PartitionFetchState(fetchOffset, None, leaderEpoch, Truncating)))
-    EasyMock.expect(fetcher.removePartitions(Set(tp)))
+      .andReturn(Some(PartitionFetchState(fetchOffset, None, leaderEpoch, Truncating, lastFetchedEpoch = None)))
+    EasyMock.expect(fetcher.removePartitions(Set(tp))).andReturn(Map.empty)
     EasyMock.expect(fetcher.fetchState(tp)).andReturn(None)
     EasyMock.replay(fetcher)
 
@@ -116,7 +116,7 @@ class AbstractFetcherManagerTest {
       initOffset = fetchOffset)
 
     EasyMock.expect(fetcher.start())
-    EasyMock.expect(fetcher.addPartitions(Map(tp -> OffsetAndEpoch(fetchOffset, leaderEpoch))))
+    EasyMock.expect(fetcher.addPartitions(Map(tp -> initialFetchState)))
         .andReturn(Set(tp))
     EasyMock.expect(fetcher.isThreadFailed).andReturn(true)
     EasyMock.replay(fetcher)
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
index 4a5633d..6a430b4 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
@@ -31,6 +31,7 @@ import kafka.utils.TestUtils
 import org.apache.kafka.common.KafkaException
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.{FencedLeaderEpochException, UnknownLeaderEpochException}
+import org.apache.kafka.common.message.FetchResponseData
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.requests.{EpochEndOffset, FetchRequest}
@@ -44,9 +45,11 @@ import scala.util.Random
 import org.scalatest.Assertions.assertThrows
 
 import scala.collection.mutable.ArrayBuffer
+import scala.compat.java8.OptionConverters._
 
 class AbstractFetcherThreadTest {
 
+  val truncateOnFetch = true
   private val partition1 = new TopicPartition("topic1", 0)
   private val partition2 = new TopicPartition("topic2", 0)
   private val failedPartitions = new FailedPartitions
@@ -63,8 +66,9 @@ class AbstractFetcherThreadTest {
       .batches.asScala.head
   }
 
-  private def offsetAndEpoch(fetchOffset: Long, leaderEpoch: Int): OffsetAndEpoch = {
-    OffsetAndEpoch(offset = fetchOffset, leaderEpoch = leaderEpoch)
+  private def initialFetchState(fetchOffset: Long, leaderEpoch: Int): InitialFetchState = {
+    InitialFetchState(leader = new BrokerEndPoint(0, "localhost", 9092),
+      initOffset = fetchOffset, currentLeaderEpoch = leaderEpoch)
   }
 
   @Test
@@ -74,7 +78,7 @@ class AbstractFetcherThreadTest {
 
     // add one partition to create the consumer lag metric
     fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 0)))
     fetcher.setLeaderState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
 
     fetcher.start()
@@ -101,7 +105,7 @@ class AbstractFetcherThreadTest {
 
     // add one partition to create the consumer lag metric
     fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 0)))
     fetcher.setLeaderState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
 
     fetcher.doWork()
@@ -122,7 +126,7 @@ class AbstractFetcherThreadTest {
     val fetcher = new MockFetcherThread
 
     fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 0)))
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
       new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
@@ -142,7 +146,7 @@ class AbstractFetcherThreadTest {
     val fetcher = new MockFetcherThread
 
     fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 0)))
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 1,
       new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
@@ -168,7 +172,7 @@ class AbstractFetcherThreadTest {
 
     val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 0)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 0)))
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
       new SimpleRecord("a".getBytes),
@@ -199,7 +203,7 @@ class AbstractFetcherThreadTest {
     // The replica's leader epoch is ahead of the leader
     val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 1)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 1)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 1)), forceTruncation = true)
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0, new SimpleRecord("a".getBytes))
     val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
@@ -232,7 +236,7 @@ class AbstractFetcherThreadTest {
 
     val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 1)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 1)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 1)))
 
     val leaderState = MockFetcherThread.PartitionState(Seq(
       mkBatch(baseOffset = 0L, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
@@ -273,7 +277,7 @@ class AbstractFetcherThreadTest {
 
     val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(3L, leaderEpoch = 5)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(3L, leaderEpoch = 5)))
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 1, new SimpleRecord("a".getBytes)),
@@ -307,7 +311,9 @@ class AbstractFetcherThreadTest {
       override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] =
         throw new UnsupportedOperationException
 
-      override protected def isOffsetForLeaderEpochSupported: Boolean = false
+      override protected val isOffsetForLeaderEpochSupported: Boolean = false
+
+      override protected val isTruncationOnFetchSupported: Boolean = false
     }
 
     val replicaLog = Seq(
@@ -317,7 +323,7 @@ class AbstractFetcherThreadTest {
 
     val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(highWatermark, leaderEpoch = 5)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(highWatermark, leaderEpoch = 5)))
 
     fetcher.doWork()
 
@@ -350,7 +356,7 @@ class AbstractFetcherThreadTest {
 
     val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(highWatermark, leaderEpoch = 5)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(highWatermark, leaderEpoch = 5)))
 
     fetcher.doWork()
 
@@ -379,7 +385,7 @@ class AbstractFetcherThreadTest {
 
     val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(highWatermark, leaderEpoch = 5)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(highWatermark, leaderEpoch = 5)))
 
     fetcher.doWork()
 
@@ -401,7 +407,7 @@ class AbstractFetcherThreadTest {
 
     val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 5)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 5)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 5)), forceTruncation = true)
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 1, new SimpleRecord("a".getBytes)),
@@ -419,7 +425,7 @@ class AbstractFetcherThreadTest {
     assertEquals(1, truncations)
 
     // Add partitions again with the same epoch
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(3L, leaderEpoch = 5)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(3L, leaderEpoch = 5)))
 
     // Verify we did not truncate
     fetcher.doWork()
@@ -441,7 +447,7 @@ class AbstractFetcherThreadTest {
 
     val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 4, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(3L, leaderEpoch = 4)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(3L, leaderEpoch = 4)))
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
@@ -483,7 +489,7 @@ class AbstractFetcherThreadTest {
     val replicaLog = Seq()
     val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 4, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 4)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 4)))
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
@@ -510,7 +516,7 @@ class AbstractFetcherThreadTest {
 
     val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(3L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(3L, leaderEpoch = 0)))
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
@@ -520,6 +526,11 @@ class AbstractFetcherThreadTest {
 
     // initial truncation and verify that the log start offset is updated
     fetcher.doWork()
+    if (truncateOnFetch) {
+      // A second doWork() is required here: the first iteration only performs
+      // the initial truncation based on the diverging epoch.
+      fetcher.doWork()
+    }
     assertEquals(Option(Fetching), fetcher.fetchState(partition).map(_.state))
     assertEquals(2, replicaState.logStartOffset)
     assertEquals(List(), replicaState.log.toList)
@@ -552,7 +563,7 @@ class AbstractFetcherThreadTest {
 
     val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(3L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(3L, leaderEpoch = 0)))
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
@@ -594,7 +605,7 @@ class AbstractFetcherThreadTest {
     }
 
     fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 0)))
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
       new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
@@ -639,7 +650,7 @@ class AbstractFetcherThreadTest {
           // leader epoch changes while fetching epochs from leader
           removePartitions(Set(partition))
           setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = nextLeaderEpochOnFollower))
-          addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = nextLeaderEpochOnFollower)))
+          addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = nextLeaderEpochOnFollower)), forceTruncation = true)
           fetchEpochsFromLeaderOnce = true
         }
         fetchedEpochs
@@ -647,7 +658,7 @@ class AbstractFetcherThreadTest {
     }
 
     fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = initialLeaderEpochOnFollower))
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = initialLeaderEpochOnFollower)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = initialLeaderEpochOnFollower)), forceTruncation = true)
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = initialLeaderEpochOnFollower, new SimpleRecord("c".getBytes)))
@@ -691,7 +702,7 @@ class AbstractFetcherThreadTest {
     }
 
     fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = initialLeaderEpochOnFollower))
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = initialLeaderEpochOnFollower)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = initialLeaderEpochOnFollower)))
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = initialLeaderEpochOnFollower, new SimpleRecord("c".getBytes)))
@@ -725,7 +736,7 @@ class AbstractFetcherThreadTest {
     }
 
     fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
-    fetcher.addPartitions(Map(partition -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition -> initialFetchState(0L, leaderEpoch = 0)), forceTruncation = true)
     fetcher.setLeaderState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
 
     // first round of truncation should throw an exception
@@ -765,11 +776,11 @@ class AbstractFetcherThreadTest {
   private def verifyFetcherThreadHandlingPartitionFailure(fetcher: MockFetcherThread): Unit = {
 
     fetcher.setReplicaState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0))
-    fetcher.addPartitions(Map(partition1 -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition1 -> initialFetchState(0L, leaderEpoch = 0)), forceTruncation = true)
     fetcher.setLeaderState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0))
 
     fetcher.setReplicaState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0))
-    fetcher.addPartitions(Map(partition2 -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.addPartitions(Map(partition2 -> initialFetchState(0L, leaderEpoch = 0)), forceTruncation = true)
     fetcher.setLeaderState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0))
 
     // processing data fails for partition1
@@ -787,7 +798,7 @@ class AbstractFetcherThreadTest {
     // simulate a leader change
     fetcher.removePartitions(Set(partition1))
     failedPartitions.removeAll(Set(partition1))
-    fetcher.addPartitions(Map(partition1 -> offsetAndEpoch(0L, leaderEpoch = 1)))
+    fetcher.addPartitions(Map(partition1 -> initialFetchState(0L, leaderEpoch = 1)), forceTruncation = true)
 
     // partition1 added back
     assertEquals(Some(Truncating), fetcher.fetchState(partition1).map(_.state))
@@ -795,6 +806,40 @@ class AbstractFetcherThreadTest {
 
   }
 
+  @Test
+  def testDivergingEpochs(): Unit = {
+    val partition = new TopicPartition("topic", 0)
+    val fetcher = new MockFetcherThread
+
+    val replicaLog = Seq(
+      mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
+      mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
+      mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
+
+    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 0L)
+    fetcher.setReplicaState(partition, replicaState)
+    fetcher.addPartitions(Map(partition -> initialFetchState(3L, leaderEpoch = 5)))
+    assertEquals(3L, replicaState.logEndOffset)
+    fetcher.verifyLastFetchedEpoch(partition, expectedEpoch = Some(4))
+
+    val leaderLog = Seq(
+      mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
+      mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
+      mkBatch(baseOffset = 2, leaderEpoch = 5, new SimpleRecord("d".getBytes)))
+
+    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 2L)
+    fetcher.setLeaderState(partition, leaderState)
+
+    fetcher.doWork()
+    fetcher.verifyLastFetchedEpoch(partition, Some(2))
+
+    TestUtils.waitUntilTrue(() => {
+      fetcher.doWork()
+      fetcher.replicaPartitionState(partition).log == fetcher.leaderPartitionState(partition).log
+    }, "Failed to reconcile leader and follower logs")
+    fetcher.verifyLastFetchedEpoch(partition, Some(5))
+  }
+
   object MockFetcherThread {
     class PartitionState(var log: mutable.Buffer[RecordBatch],
                          var leaderEpoch: Int,
@@ -826,6 +871,7 @@ class AbstractFetcherThreadTest {
 
     private val replicaPartitionStates = mutable.Map[TopicPartition, PartitionState]()
     private val leaderPartitionStates = mutable.Map[TopicPartition, PartitionState]()
+    private var latestEpochDefault: Option[Int] = Some(0)
 
     def setLeaderState(topicPartition: TopicPartition, state: PartitionState): Unit = {
       leaderPartitionStates.put(topicPartition, state)
@@ -845,11 +891,24 @@ class AbstractFetcherThreadTest {
         throw new IllegalArgumentException(s"Unknown partition $topicPartition"))
     }
 
+    def addPartitions(initialFetchStates: Map[TopicPartition, InitialFetchState], forceTruncation: Boolean): Set[TopicPartition] = {
+      latestEpochDefault = if (forceTruncation) None else Some(0)
+      val partitions = super.addPartitions(initialFetchStates)
+      latestEpochDefault = Some(0)
+      partitions
+    }
+
     override def processPartitionData(topicPartition: TopicPartition,
                                       fetchOffset: Long,
                                       partitionData: FetchData): Option[LogAppendInfo] = {
       val state = replicaPartitionState(topicPartition)
 
+      if (isTruncationOnFetchSupported && partitionData.divergingEpoch.isPresent) {
+        val divergingEpoch = partitionData.divergingEpoch.get
+        truncateOnFetchResponse(Map(topicPartition -> new EpochEndOffset(Errors.NONE, divergingEpoch.epoch, divergingEpoch.endOffset)))
+        return None
+      }
+
       // Throw exception if the fetchOffset does not match the fetcherThread partition state
       if (fetchOffset != state.logEndOffset)
         throw new RuntimeException(s"Offset mismatch for partition $topicPartition: " +
@@ -860,6 +919,7 @@ class AbstractFetcherThreadTest {
       var maxTimestamp = RecordBatch.NO_TIMESTAMP
       var offsetOfMaxTimestamp = -1L
       var lastOffset = state.logEndOffset
+      var lastEpoch: Option[Int] = None
 
       for (batch <- batches) {
         batch.ensureValid()
@@ -870,6 +930,7 @@ class AbstractFetcherThreadTest {
         state.log.append(batch)
         state.logEndOffset = batch.nextOffset
         lastOffset = batch.lastOffset
+        lastEpoch = Some(batch.partitionLeaderEpoch)
       }
 
       state.logStartOffset = partitionData.logStartOffset
@@ -877,6 +938,7 @@ class AbstractFetcherThreadTest {
 
       Some(LogAppendInfo(firstOffset = Some(fetchOffset),
         lastOffset = lastOffset,
+        lastLeaderEpoch = lastEpoch,
         maxTimestamp = maxTimestamp,
         offsetOfMaxTimestamp = offsetOfMaxTimestamp,
         logAppendTime = Time.SYSTEM.milliseconds(),
@@ -912,8 +974,12 @@ class AbstractFetcherThreadTest {
       partitionMap.foreach { case (partition, state) =>
         if (state.isReadyForFetch) {
           val replicaState = replicaPartitionState(partition)
+          val lastFetchedEpoch = if (isTruncationOnFetchSupported)
+            state.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
+          else
+            Optional.empty[Integer]
           fetchData.put(partition, new FetchRequest.PartitionData(state.fetchOffset, replicaState.logStartOffset,
-            1024 * 1024, Optional.of[Integer](state.currentLeaderEpoch)))
+            1024 * 1024, Optional.of[Integer](state.currentLeaderEpoch), lastFetchedEpoch))
         }
       }
       val fetchRequest = FetchRequest.Builder.forReplica(ApiKeys.FETCH.latestVersion, replicaId, 0, 1, fetchData.asJava)
@@ -922,7 +988,7 @@ class AbstractFetcherThreadTest {
 
     override def latestEpoch(topicPartition: TopicPartition): Option[Int] = {
       val state = replicaPartitionState(topicPartition)
-      state.log.lastOption.map(_.partitionLeaderEpoch).orElse(Some(EpochEndOffset.UNDEFINED_EPOCH))
+      state.log.lastOption.map(_.partitionLeaderEpoch).orElse(latestEpochDefault)
     }
 
     override def logStartOffset(topicPartition: TopicPartition): Long = replicaPartitionState(topicPartition).logStartOffset
@@ -953,6 +1019,31 @@ class AbstractFetcherThreadTest {
       }
     }
 
+    def verifyLastFetchedEpoch(partition: TopicPartition, expectedEpoch: Option[Int]): Unit = {
+      if (isTruncationOnFetchSupported) {
+        assertEquals(Some(Fetching), fetchState(partition).map(_.state))
+        assertEquals(expectedEpoch, fetchState(partition).flatMap(_.lastFetchedEpoch))
+      }
+    }
+
+    private def divergingEpochAndOffset(partition: TopicPartition,
+                                        lastFetchedEpoch: Optional[Integer],
+                                        fetchOffset: Long,
+                                        partitionState: PartitionState): Option[FetchResponseData.EpochEndOffset] = {
+      lastFetchedEpoch.asScala.flatMap { fetchEpoch =>
+        val epochEndOffset = fetchEpochEndOffsets(Map(partition -> new EpochData(Optional.empty[Integer], fetchEpoch)))(partition)
+
+        if (partitionState.log.isEmpty || epochEndOffset.hasUndefinedEpochOrOffset)
+          None
+        else if (epochEndOffset.leaderEpoch < fetchEpoch || epochEndOffset.endOffset < fetchOffset) {
+          Some(new FetchResponseData.EpochEndOffset()
+            .setEpoch(epochEndOffset.leaderEpoch)
+            .setEndOffset(epochEndOffset.endOffset))
+        } else
+          None
+      }
+    }
+
     private def lookupEndOffsetForEpoch(epochData: EpochData,
                                         partitionState: PartitionState): EpochEndOffset = {
       checkExpectedLeaderEpoch(epochData.currentLeaderEpoch, partitionState).foreach { error =>
@@ -962,7 +1053,11 @@ class AbstractFetcherThreadTest {
       var epochLowerBound = EpochEndOffset.UNDEFINED_EPOCH
       for (batch <- partitionState.log) {
         if (batch.partitionLeaderEpoch > epochData.leaderEpoch) {
-          return new EpochEndOffset(Errors.NONE, epochLowerBound, batch.baseOffset)
+          // If we don't have the requested epoch, return the next higher entry
+          if (epochLowerBound == EpochEndOffset.UNDEFINED_EPOCH)
+            return new EpochEndOffset(Errors.NONE, batch.partitionLeaderEpoch, batch.baseOffset)
+          else
+            return new EpochEndOffset(Errors.NONE, epochLowerBound, batch.baseOffset)
         }
         epochLowerBound = batch.partitionLeaderEpoch
       }
@@ -979,17 +1074,22 @@ class AbstractFetcherThreadTest {
       endOffsets
     }
 
-    override protected def isOffsetForLeaderEpochSupported: Boolean = true
+    override protected val isOffsetForLeaderEpochSupported: Boolean = true
+
+    override protected val isTruncationOnFetchSupported: Boolean = truncateOnFetch
 
     override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
       fetchRequest.fetchData.asScala.map { case (partition, fetchData) =>
         val leaderState = leaderPartitionState(partition)
         val epochCheckError = checkExpectedLeaderEpoch(fetchData.currentLeaderEpoch, leaderState)
+        val divergingEpoch = divergingEpochAndOffset(partition, fetchData.lastFetchedEpoch, fetchData.fetchOffset, leaderState)
 
         val (error, records) = if (epochCheckError.isDefined) {
           (epochCheckError.get, MemoryRecords.EMPTY)
         } else if (fetchData.fetchOffset > leaderState.logEndOffset || fetchData.fetchOffset < leaderState.logStartOffset) {
           (Errors.OFFSET_OUT_OF_RANGE, MemoryRecords.EMPTY)
+        } else if (divergingEpoch.nonEmpty) {
+          (Errors.NONE, MemoryRecords.EMPTY)
         } else {
           // for simplicity, we fetch only one batch at a time
           val records = leaderState.log.find(_.baseOffset >= fetchData.fetchOffset) match {
@@ -1007,7 +1107,7 @@ class AbstractFetcherThreadTest {
         }
 
         (partition, new FetchData(error, leaderState.highWatermark, leaderState.highWatermark, leaderState.logStartOffset,
-          List.empty.asJava, records))
+          Optional.empty[Integer], List.empty.asJava, divergingEpoch.asJava, records))
       }.toMap
     }
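
The mock leader above returns an empty record set together with a divergingEpoch whenever the follower's lastFetchedEpoch does not line up with its log. On the follower side, the corresponding branch truncates instead of appending. A minimal sketch of that branch, assuming FetchData exposes the diverging epoch as a java Optional (as the constructor call above suggests); appendAsFollower is a hypothetical stand-in for the normal append path, not the committed implementation:

    def handleFetchedData(topicPartition: TopicPartition,
                          partitionData: FetchData): Option[LogAppendInfo] = {
      if (partitionData.divergingEpoch.isPresent) {
        val diverging = partitionData.divergingEpoch.get
        // Truncate using the epoch/offset the leader reported, and append nothing
        truncateOnFetchResponse(Map(topicPartition ->
          new EpochEndOffset(Errors.NONE, diverging.epoch, diverging.endOffset)))
        None
      } else {
        appendAsFollower(topicPartition, partitionData) // hypothetical normal path
      }
    }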
 
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadWithIbp26Test.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadWithIbp26Test.scala
new file mode 100644
index 0000000..a852465
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadWithIbp26Test.scala
@@ -0,0 +1,23 @@
+/**
+  * Licensed to the Apache Software Foundation (ASF) under one or more
+  * contributor license agreements.  See the NOTICE file distributed with
+  * this work for additional information regarding copyright ownership.
+  * The ASF licenses this file to You under the Apache License, Version 2.0
+  * (the "License"); you may not use this file except in compliance with
+  * the License.  You may obtain a copy of the License at
+  *
+  *    http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+package kafka.server
+
+class AbstractFetcherThreadWithIbp26Test extends AbstractFetcherThreadTest {
+
+  override val truncateOnFetch = false
+}
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
index 2d45b73..2547c0c 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
@@ -46,8 +46,9 @@ class ReplicaAlterLogDirsThreadTest {
   private val t1p1 = new TopicPartition("topic1", 1)
   private val failedPartitions = new FailedPartitions
 
-  private def offsetAndEpoch(fetchOffset: Long, leaderEpoch: Int = 1): OffsetAndEpoch = {
-    OffsetAndEpoch(offset = fetchOffset, leaderEpoch = leaderEpoch)
+  private def initialFetchState(fetchOffset: Long, leaderEpoch: Int = 1): InitialFetchState = {
+    InitialFetchState(leader = new BrokerEndPoint(0, "localhost", 9092),
+      initOffset = fetchOffset, currentLeaderEpoch = leaderEpoch)
   }
 
   @Test
@@ -70,7 +71,7 @@ class ReplicaAlterLogDirsThreadTest {
       quota = quotaManager,
       brokerTopicStats = new BrokerTopicStats)
 
-    val addedPartitions = thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L)))
+    val addedPartitions = thread.addPartitions(Map(t1p0 -> initialFetchState(0L)))
     assertEquals(Set.empty, addedPartitions)
     assertEquals(0, thread.partitionCount)
     assertEquals(None, thread.fetchState(t1p0))
@@ -131,7 +132,7 @@ class ReplicaAlterLogDirsThreadTest {
       brokerTopicStats = new BrokerTopicStats)
 
     // Initially we add the partition with an older epoch which results in an error
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(fetchOffset = 0L, leaderEpoch - 1)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(fetchOffset = 0L, leaderEpoch - 1)))
     assertTrue(thread.fetchState(t1p0).isDefined)
     assertEquals(1, thread.partitionCount)
 
@@ -142,7 +143,7 @@ class ReplicaAlterLogDirsThreadTest {
     assertEquals(0, thread.partitionCount)
 
     // Next we update the epoch and assert that we can continue
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(fetchOffset = 0L, leaderEpoch)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(fetchOffset = 0L, leaderEpoch)))
     assertEquals(Some(leaderEpoch), thread.fetchState(t1p0).map(_.currentLeaderEpoch))
     assertEquals(1, thread.partitionCount)
 
@@ -221,7 +222,7 @@ class ReplicaAlterLogDirsThreadTest {
       quota = quotaManager,
       brokerTopicStats = new BrokerTopicStats)
 
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(fetchOffset = 0L, leaderEpoch)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(fetchOffset = 0L, leaderEpoch)))
     assertTrue(thread.fetchState(t1p0).isDefined)
     assertEquals(1, thread.partitionCount)
 
@@ -421,7 +422,7 @@ class ReplicaAlterLogDirsThreadTest {
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t1p1 -> initialFetchState(0L)))
 
     //Run it
     thread.doWork()
@@ -463,7 +464,7 @@ class ReplicaAlterLogDirsThreadTest {
     expect(partition.truncateTo(capture(truncateToCapture), EasyMock.eq(true))).anyTimes()
     expect(futureLog.logEndOffset).andReturn(futureReplicaLEO).anyTimes()
     expect(futureLog.latestEpoch).andReturn(Some(leaderEpoch)).once()
-    expect(futureLog.latestEpoch).andReturn(Some(leaderEpoch - 2)).once()
+    expect(futureLog.latestEpoch).andReturn(Some(leaderEpoch - 2)).times(3)
 
     // leader replica truncated and fetched new offsets with new leader epoch
     expect(partition.lastOffsetForLeaderEpoch(Optional.of(1), leaderEpoch, fetchOnlyFromLeader = false))
@@ -494,7 +495,7 @@ class ReplicaAlterLogDirsThreadTest {
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L)))
 
     // First run will result in another offset for leader epoch request
     thread.doWork()
@@ -549,7 +550,7 @@ class ReplicaAlterLogDirsThreadTest {
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(initialFetchOffset)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(initialFetchOffset)))
 
     //Run it
     thread.doWork()
@@ -624,7 +625,7 @@ class ReplicaAlterLogDirsThreadTest {
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L)))
 
     // Run thread 3 times (exactly number of times we mock exception for getReplicaOrException)
     (0 to 2).foreach { _ =>
@@ -685,7 +686,7 @@ class ReplicaAlterLogDirsThreadTest {
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L)))
 
     // loop few times
     (0 to 3).foreach { _ =>
@@ -726,12 +727,12 @@ class ReplicaAlterLogDirsThreadTest {
       quota = quotaManager,
       brokerTopicStats = null)
     thread.addPartitions(Map(
-      t1p0 -> offsetAndEpoch(0L, leaderEpoch),
-      t1p1 -> offsetAndEpoch(0L, leaderEpoch)))
+      t1p0 -> initialFetchState(0L, leaderEpoch),
+      t1p1 -> initialFetchState(0L, leaderEpoch)))
 
     val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.buildFetch(Map(
-      t1p0 -> PartitionFetchState(150, None, leaderEpoch, None, state = Fetching),
-      t1p1 -> PartitionFetchState(160, None, leaderEpoch, None, state = Fetching)))
+      t1p0 -> PartitionFetchState(150, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None),
+      t1p1 -> PartitionFetchState(160, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None)))
 
     assertTrue(fetchRequestOpt.isDefined)
     val fetchRequest = fetchRequestOpt.get.fetchRequest
@@ -777,13 +778,13 @@ class ReplicaAlterLogDirsThreadTest {
       quota = quotaManager,
       brokerTopicStats = null)
     thread.addPartitions(Map(
-      t1p0 -> offsetAndEpoch(0L, leaderEpoch),
-      t1p1 -> offsetAndEpoch(0L, leaderEpoch)))
+      t1p0 -> initialFetchState(0L, leaderEpoch),
+      t1p1 -> initialFetchState(0L, leaderEpoch)))
 
     // one partition is ready and one is truncating
     val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.buildFetch(Map(
-        t1p0 -> PartitionFetchState(150, None, leaderEpoch, state = Fetching),
-        t1p1 -> PartitionFetchState(160, None, leaderEpoch, state = Truncating)))
+        t1p0 -> PartitionFetchState(150, None, leaderEpoch, state = Fetching, lastFetchedEpoch = None),
+        t1p1 -> PartitionFetchState(160, None, leaderEpoch, state = Truncating, lastFetchedEpoch = None)))
 
     assertTrue(fetchRequestOpt.isDefined)
     val fetchRequest = fetchRequestOpt.get
@@ -796,8 +797,8 @@ class ReplicaAlterLogDirsThreadTest {
 
     // one partition is ready and one is delayed
     val ResultWithPartitions(fetchRequest2Opt, partitionsWithError2) = thread.buildFetch(Map(
-        t1p0 -> PartitionFetchState(140, None, leaderEpoch, state = Fetching),
-        t1p1 -> PartitionFetchState(160, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching)))
+        t1p0 -> PartitionFetchState(140, None, leaderEpoch, state = Fetching, lastFetchedEpoch = None),
+        t1p1 -> PartitionFetchState(160, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching, lastFetchedEpoch = None)))
 
     assertTrue(fetchRequest2Opt.isDefined)
     val fetchRequest2 = fetchRequest2Opt.get
@@ -810,8 +811,8 @@ class ReplicaAlterLogDirsThreadTest {
 
     // both partitions are delayed
     val ResultWithPartitions(fetchRequest3Opt, partitionsWithError3) = thread.buildFetch(Map(
-        t1p0 -> PartitionFetchState(140, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching),
-        t1p1 -> PartitionFetchState(160, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching)))
+        t1p0 -> PartitionFetchState(140, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching, lastFetchedEpoch = None),
+        t1p1 -> PartitionFetchState(160, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching, lastFetchedEpoch = None)))
     assertTrue("Expected no fetch requests since all partitions are delayed", fetchRequest3Opt.isEmpty)
     assertFalse(partitionsWithError3.nonEmpty)
   }
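
The lastFetchedEpoch now carried by PartitionFetchState ends up in the fetch request itself. A minimal sketch of the five-argument FetchRequest.PartitionData constructor as used in the AbstractFetcherThreadTest hunk earlier; the literal offsets and epochs here are illustrative only:

    import java.util.Optional
    import org.apache.kafka.common.requests.FetchRequest

    val lastFetchedEpoch: Optional[Integer] = Optional.of[Integer](4) // illustrative epoch
    val partitionData = new FetchRequest.PartitionData(
      150L,                     // fetchOffset
      0L,                       // logStartOffset
      1024 * 1024,              // maxBytes
      Optional.of[Integer](5),  // currentLeaderEpoch
      lastFetchedEpoch)         // new field: lets the leader detect divergence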
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
index 3fc5018..ae2dbda 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
@@ -19,12 +19,14 @@ package kafka.server
 import java.nio.charset.StandardCharsets
 import java.util.{Collections, Optional}
 
+import kafka.api.{ApiVersion, KAFKA_2_6_IV0}
 import kafka.cluster.{BrokerEndPoint, Partition}
-import kafka.log.{Log, LogManager}
+import kafka.log.{Log, LogAppendInfo, LogManager}
 import kafka.server.QuotaFactory.UnboundedQuota
 import kafka.server.epoch.util.ReplicaFetcherMockBlockingSend
 import kafka.utils.TestUtils
 import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.message.FetchResponseData
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.protocol.Errors._
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
@@ -49,8 +51,9 @@ class ReplicaFetcherThreadTest {
   private val brokerEndPoint = new BrokerEndPoint(0, "localhost", 1000)
   private val failedPartitions = new FailedPartitions
 
-  private def offsetAndEpoch(fetchOffset: Long, leaderEpoch: Int = 1): OffsetAndEpoch = {
-    OffsetAndEpoch(offset = fetchOffset, leaderEpoch = leaderEpoch)
+  private def initialFetchState(fetchOffset: Long, leaderEpoch: Int = 1): InitialFetchState = {
+    InitialFetchState(leader = new BrokerEndPoint(0, "localhost", 9092),
+      initOffset = fetchOffset, currentLeaderEpoch = leaderEpoch)
   }
 
   @After
@@ -82,8 +85,8 @@ class ReplicaFetcherThreadTest {
   }
 
   @Test
-  def shouldFetchLeaderEpochRequestIfLastEpochDefinedForSomePartitions(): Unit = {
-    val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234"))
+  def testFetchLeaderEpochRequestIfLastEpochDefinedForSomePartitions(): Unit = {
+    val config = kafkaConfigNoTruncateOnFetch
 
     //Setup all dependencies
     val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager])
@@ -123,11 +126,11 @@ class ReplicaFetcherThreadTest {
 
     val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
 
-    // topic 1 supports epoch, t2 doesn't
+    // topic 1 supports epoch, t2 doesn't.
     thread.addPartitions(Map(
-      t1p0 -> offsetAndEpoch(0L),
-      t1p1 -> offsetAndEpoch(0L),
-      t2p1 -> offsetAndEpoch(0L)))
+      t1p0 -> initialFetchState(0L),
+      t1p1 -> initialFetchState(0L),
+      t2p1 -> initialFetchState(0L)))
 
     assertPartitionStates(thread, shouldBeReadyForFetch = false, shouldBeTruncatingLog = true, shouldBeDelayed = false)
     //Loop 1
@@ -218,8 +221,19 @@ class ReplicaFetcherThreadTest {
   }
 
   @Test
-  def shouldFetchLeaderEpochOnFirstFetchOnlyIfLeaderEpochKnownToBoth(): Unit = {
-    val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234"))
+  def shouldFetchLeaderEpochOnFirstFetchOnlyIfLeaderEpochKnownToBothIbp26(): Unit = {
+    verifyFetchLeaderEpochOnFirstFetch(KAFKA_2_6_IV0)
+  }
+
+  @Test
+  def shouldNotFetchLeaderEpochOnFirstFetchWithTruncateOnFetch(): Unit = {
+    verifyFetchLeaderEpochOnFirstFetch(ApiVersion.latestVersion, epochFetchCount = 0)
+  }
+
+  private def verifyFetchLeaderEpochOnFirstFetch(ibp: ApiVersion, epochFetchCount: Int = 1): Unit = {
+    val props = TestUtils.createBrokerConfig(1, "localhost:1234")
+    props.setProperty(KafkaConfig.InterBrokerProtocolVersionProp, ibp.version)
+    val config = KafkaConfig.fromProps(props)
 
     //Setup all dependencies
     val logManager: LogManager = createMock(classOf[LogManager])
@@ -253,21 +267,21 @@ class ReplicaFetcherThreadTest {
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
     val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager,
       new Metrics, new SystemTime, UnboundedQuota, Some(mockNetwork))
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t1p1 -> initialFetchState(0L)))
 
     //Loop 1
     thread.doWork()
-    assertEquals(1, mockNetwork.epochFetchCount)
+    assertEquals(epochFetchCount, mockNetwork.epochFetchCount)
     assertEquals(1, mockNetwork.fetchCount)
 
     //Loop 2 we should not fetch epochs
     thread.doWork()
-    assertEquals(1, mockNetwork.epochFetchCount)
+    assertEquals(epochFetchCount, mockNetwork.epochFetchCount)
     assertEquals(2, mockNetwork.fetchCount)
 
     //Loop 3 we should not fetch epochs
     thread.doWork()
-    assertEquals(1, mockNetwork.epochFetchCount)
+    assertEquals(epochFetchCount, mockNetwork.epochFetchCount)
     assertEquals(3, mockNetwork.fetchCount)
 
     //Assert that truncateTo is called exactly once (despite multiple fetch loops)
@@ -281,7 +295,7 @@ class ReplicaFetcherThreadTest {
     val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL)
 
     // Setup all the dependencies
-    val configs = TestUtils.createBrokerConfigs(1, "localhost:1234").map(KafkaConfig.fromProps)
+    val config = kafkaConfigNoTruncateOnFetch
     val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager])
     val logManager: LogManager = createMock(classOf[LogManager])
     val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager])
@@ -313,9 +327,9 @@ class ReplicaFetcherThreadTest {
 
     //Create the thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs.head, failedPartitions, replicaManager,
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager,
       new Metrics(), new SystemTime(), quota, Some(mockNetwork))
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t2p1 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t2p1 -> initialFetchState(0L)))
 
     //Run it
     thread.doWork()
@@ -333,7 +347,7 @@ class ReplicaFetcherThreadTest {
     val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL)
 
     // Setup all the dependencies
-    val configs = TestUtils.createBrokerConfigs(1, "localhost:1234").map(KafkaConfig.fromProps)
+    val config = kafkaConfigNoTruncateOnFetch
     val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager])
     val logManager: LogManager = createMock(classOf[LogManager])
     val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager])
@@ -366,9 +380,9 @@ class ReplicaFetcherThreadTest {
 
     //Create the thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs.head, failedPartitions,
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions,
       replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t2p1 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t2p1 -> initialFetchState(0L)))
 
     //Run it
     thread.doWork()
@@ -383,11 +397,10 @@ class ReplicaFetcherThreadTest {
 
   @Test
   def shouldFetchLeaderEpochSecondTimeIfLeaderRepliesWithEpochNotKnownToFollower(): Unit = {
-
     // Create a capture to track what partitions/offsets are truncated
     val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL)
 
-    val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234"))
+    val config = kafkaConfigNoTruncateOnFetch
 
     // Setup all dependencies
     val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager])
@@ -423,7 +436,7 @@ class ReplicaFetcherThreadTest {
     // Create the fetcher thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
     val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t1p1 -> initialFetchState(0L)))
 
     // Loop 1 -- both topic partitions will need to fetch another leader epoch
     thread.doWork()
@@ -454,6 +467,107 @@ class ReplicaFetcherThreadTest {
   }
 
   @Test
+  def shouldTruncateIfLeaderRepliesWithDivergingEpochNotKnownToFollower(): Unit = {
+
+    // Create a capture to track what partitions/offsets are truncated
+    val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL)
+
+    val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234"))
+
+    // Setup all dependencies
+    val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager])
+    val logManager: LogManager = createMock(classOf[LogManager])
+    val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager])
+    val log: Log = createNiceMock(classOf[Log])
+    val partition: Partition = createNiceMock(classOf[Partition])
+    val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager])
+
+    val initialLEO = 200
+    var latestLogEpoch: Option[Int] = Some(5)
+
+    // Stubs
+    expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).anyTimes()
+    expect(partition.localLogOrException).andReturn(log).anyTimes()
+    expect(log.highWatermark).andReturn(115).anyTimes()
+    expect(log.latestEpoch).andAnswer(() => latestLogEpoch).anyTimes()
+    expect(log.endOffsetForEpoch(4)).andReturn(Some(OffsetAndEpoch(149, 4))).anyTimes()
+    expect(log.endOffsetForEpoch(3)).andReturn(Some(OffsetAndEpoch(129, 2))).anyTimes()
+    expect(log.endOffsetForEpoch(2)).andReturn(Some(OffsetAndEpoch(119, 1))).anyTimes()
+    expect(log.logEndOffset).andReturn(initialLEO).anyTimes()
+    expect(replicaManager.localLogOrException(anyObject(classOf[TopicPartition]))).andReturn(log).anyTimes()
+    expect(replicaManager.logManager).andReturn(logManager).anyTimes()
+    expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes()
+    expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats]))
+    stub(partition, replicaManager, log)
+
+    replay(replicaManager, logManager, quota, partition, log)
+
+    // Create the fetcher thread
+    val mockNetwork = new ReplicaFetcherMockBlockingSend(Collections.emptyMap(), brokerEndPoint, new SystemTime())
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork)) {
+      override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = None
+    }
+    thread.addPartitions(Map(t1p0 -> initialFetchState(initialLEO), t1p1 -> initialFetchState(initialLEO)))
+    val partitions = Set(t1p0, t1p1)
+
+    // Loop 1 -- both topic partitions skip the epoch fetch and send a fetch request, since we can
+    // truncate later based on diverging epochs in the fetch response.
+    thread.doWork()
+    assertEquals(0, mockNetwork.epochFetchCount)
+    assertEquals(1, mockNetwork.fetchCount)
+    partitions.foreach { tp => assertEquals(Fetching, thread.fetchState(tp).get.state) }
+
+    def partitionData(divergingEpoch: FetchResponseData.EpochEndOffset): FetchResponse.PartitionData[Records] = {
+      new FetchResponse.PartitionData[Records](
+        Errors.NONE, 0, 0, 0, Optional.empty(), Collections.emptyList(),
+        Optional.of(divergingEpoch), MemoryRecords.EMPTY)
+    }
+
+    // Loop 2 should truncate based on diverging epoch and continue to send fetch requests.
+    mockNetwork.setFetchPartitionDataForNextResponse(Map(
+      t1p0 -> partitionData(new FetchResponseData.EpochEndOffset().setEpoch(4).setEndOffset(140)),
+      t1p1 -> partitionData(new FetchResponseData.EpochEndOffset().setEpoch(4).setEndOffset(141))
+    ))
+    latestLogEpoch = Some(4)
+    thread.doWork()
+    assertEquals(0, mockNetwork.epochFetchCount)
+    assertEquals(2, mockNetwork.fetchCount)
+    assertTrue("Expected " + t1p0 + " to truncate to offset 140 (truncation offsets: " + truncateToCapture.getValues + ")",
+      truncateToCapture.getValues.asScala.contains(140))
+    assertTrue("Expected " + t1p1 + " to truncate to offset 141 (truncation offsets: " + truncateToCapture.getValues + ")",
+      truncateToCapture.getValues.asScala.contains(141))
+    partitions.foreach { tp => assertEquals(Fetching, thread.fetchState(tp).get.state) }
+
+    // Loop 3 should truncate because of the diverging epoch. Offset truncation is not complete
+    // because the diverging epoch is not known to the follower. We truncate and stay in Fetching state.
+    mockNetwork.setFetchPartitionDataForNextResponse(Map(
+      t1p0 -> partitionData(new FetchResponseData.EpochEndOffset().setEpoch(3).setEndOffset(130)),
+      t1p1 -> partitionData(new FetchResponseData.EpochEndOffset().setEpoch(3).setEndOffset(131))
+    ))
+    thread.doWork()
+    assertEquals(0, mockNetwork.epochFetchCount)
+    assertEquals(3, mockNetwork.fetchCount)
+    assertTrue("Expected to truncate to offset 129 (truncation offsets: " + truncateToCapture.getValues + ")",
+      truncateToCapture.getValues.asScala.contains(129))
+    partitions.foreach { tp => assertEquals(Fetching, thread.fetchState(tp).get.state) }
+
+    // Loop 4 should truncate because of the diverging epoch. Offset truncation is not complete
+    // because the diverging epoch is not known to the follower and the last fetched epoch cannot be
+    // determined from the log. We truncate and stay in Fetching state.
+    mockNetwork.setFetchPartitionDataForNextResponse(Map(
+      t1p0 -> partitionData(new FetchResponseData.EpochEndOffset().setEpoch(2).setEndOffset(120)),
+      t1p1 -> partitionData(new FetchResponseData.EpochEndOffset().setEpoch(2).setEndOffset(121))
+    ))
+    latestLogEpoch = None
+    thread.doWork()
+    assertEquals(0, mockNetwork.epochFetchCount)
+    assertEquals(4, mockNetwork.fetchCount)
+    assertTrue("Expected to truncate to offset 119 (truncation offsets: " + truncateToCapture.getValues + ")",
+      truncateToCapture.getValues.asScala.contains(119))
+    partitions.foreach { tp => assertEquals(Fetching, thread.fetchState(tp).get.state) }
+  }
+
+  @Test
   def shouldUseLeaderEndOffsetIfInterBrokerVersionBelow20(): Unit = {
 
     // Create a capture to track what partitions/offsets are truncated
@@ -498,7 +612,7 @@ class ReplicaFetcherThreadTest {
     // Create the fetcher thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
     val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t1p1 -> initialFetchState(0L)))
 
     // Loop 1 -- both topic partitions will truncate to leader offset even though they don't know
     // about leader epoch
@@ -527,7 +641,7 @@ class ReplicaFetcherThreadTest {
     val truncated: Capture[Long] = newCapture(CaptureType.ALL)
 
     // Setup all the dependencies
-    val configs = TestUtils.createBrokerConfigs(1, "localhost:1234").map(KafkaConfig.fromProps)
+    val config = kafkaConfigNoTruncateOnFetch
     val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager])
     val logManager: LogManager = createMock(classOf[LogManager])
     val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager])
@@ -541,7 +655,7 @@ class ReplicaFetcherThreadTest {
     expect(partition.truncateTo(capture(truncated), anyBoolean())).anyTimes()
     expect(partition.localLogOrException).andReturn(log).anyTimes()
     expect(log.highWatermark).andReturn(initialFetchOffset).anyTimes()
-    expect(log.latestEpoch).andReturn(Some(5))
+    expect(log.latestEpoch).andReturn(Some(5)).times(2)
     expect(replicaManager.logManager).andReturn(logManager).anyTimes()
     expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes()
     expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats]))
@@ -553,8 +667,8 @@ class ReplicaFetcherThreadTest {
 
     //Create the thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs.head, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(initialFetchOffset)))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(initialFetchOffset)))
 
     //Run it
     thread.doWork()
@@ -570,7 +684,7 @@ class ReplicaFetcherThreadTest {
     val truncated: Capture[Long] = newCapture(CaptureType.ALL)
 
     // Setup all the dependencies
-    val configs = TestUtils.createBrokerConfigs(1, "localhost:1234").map(KafkaConfig.fromProps)
+    val config = kafkaConfigNoTruncateOnFetch
     val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager])
     val logManager: LogManager = createMock(classOf[LogManager])
     val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager])
@@ -606,8 +720,8 @@ class ReplicaFetcherThreadTest {
 
     //Create the thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs.head, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t1p1 -> initialFetchState(0L)))
 
     //Run thread 3 times
     (0 to 3).foreach { _ =>
@@ -628,7 +742,7 @@ class ReplicaFetcherThreadTest {
 
   @Test
   def shouldMovePartitionsOutOfTruncatingLogState(): Unit = {
-    val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234"))
+    val config = kafkaConfigNoTruncateOnFetch
 
     //Setup all stubs
     val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager])
@@ -663,7 +777,7 @@ class ReplicaFetcherThreadTest {
     val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
 
     //When
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t1p1 -> initialFetchState(0L)))
 
     //Then all partitions should start in a TruncatingLog state
     assertEquals(Option(Truncating), thread.fetchState(t1p0).map(_.state))
@@ -679,7 +793,7 @@ class ReplicaFetcherThreadTest {
 
   @Test
   def shouldFilterPartitionsMadeLeaderDuringLeaderEpochRequest(): Unit = {
-    val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234"))
+    val config = kafkaConfigNoTruncateOnFetch
     val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL)
     val initialLEO = 100
 
@@ -716,7 +830,7 @@ class ReplicaFetcherThreadTest {
       new SystemTime(), quota, Some(mockNetwork))
 
     //When
-    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
+    thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t1p1 -> initialFetchState(0L)))
 
     //When the epoch request is outstanding, remove one of the partitions to simulate a leader change. We do this via a callback passed to the mock thread
     val partitionThatBecameLeader = t1p0
@@ -831,4 +945,10 @@ class ReplicaFetcherThreadTest {
     expect(replicaManager.localLogOrException(t2p1)).andReturn(log).anyTimes()
     expect(replicaManager.nonOfflinePartition(t2p1)).andReturn(Some(partition)).anyTimes()
   }
+
+  private def kafkaConfigNoTruncateOnFetch: KafkaConfig = {
+    val props = TestUtils.createBrokerConfig(1, "localhost:1234")
+    props.setProperty(KafkaConfig.InterBrokerProtocolVersionProp, KAFKA_2_6_IV0.version)
+    KafkaConfig.fromProps(props)
+  }
 }
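
The new shouldTruncateIfLeaderRepliesWithDivergingEpochNotKnownToFollower test encodes the truncation rule as a ladder: on each response the follower truncates to the smaller of the leader's diverging end offset and its own end offset for that epoch, then fetches again with an updated lastFetchedEpoch. A sketch of just that arithmetic, using the stubbed endOffsetForEpoch values from the test (this is the rule the assertions exercise, not the committed helper):

    def truncationOffset(divergingEndOffset: Long, localEndOffsetForEpoch: Long): Long =
      math.min(divergingEndOffset, localEndOffsetForEpoch)

    truncationOffset(140L, 149L) // loop 2: epoch 4 known locally -> truncate to 140
    truncationOffset(130L, 129L) // loop 3: epoch 3 resolves to (129, epoch 2) -> truncate to 129
    truncationOffset(120L, 119L) // loop 4: epoch 2 resolves to (119, epoch 1) -> truncate to 119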
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 761b58f..b95cf0c 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -781,12 +781,26 @@ class ReplicaManagerTest {
     }
   }
 
-  /**
-    * If a partition becomes a follower and the leader is unchanged it should check for truncation
-    * if the epoch has increased by more than one (which suggests it has missed an update)
-    */
   @Test
   def testBecomeFollowerWhenLeaderIsUnchangedButMissedLeaderUpdate(): Unit = {
+    verifyBecomeFollowerWhenLeaderIsUnchangedButMissedLeaderUpdate(new Properties, expectTruncation = false)
+  }
+
+  @Test
+  def testBecomeFollowerWhenLeaderIsUnchangedButMissedLeaderUpdateIbp26(): Unit = {
+    val extraProps = new Properties
+    extraProps.put(KafkaConfig.InterBrokerProtocolVersionProp, KAFKA_2_6_IV0.version)
+    verifyBecomeFollowerWhenLeaderIsUnchangedButMissedLeaderUpdate(extraProps, expectTruncation = true)
+  }
+
+  /**
+   * If a partition becomes a follower and the leader is unchanged it should check for truncation
+   * if the epoch has increased by more than one (which suggests it has missed an update). From
+   * IBP 2.7 onwards, this check is not required since we can truncate at any time based
+   * on diverging epochs returned in fetch responses.
+   */
+  private def verifyBecomeFollowerWhenLeaderIsUnchangedButMissedLeaderUpdate(extraProps: Properties,
+                                                                             expectTruncation: Boolean): Unit = {
     val topicPartition = 0
     val followerBrokerId = 0
     val leaderBrokerId = 1
@@ -800,7 +814,7 @@ class ReplicaManagerTest {
     // Prepare the mocked components for the test
     val (replicaManager, mockLogMgr) = prepareReplicaManagerAndLogManager(new MockTimer(time),
       topicPartition, leaderEpoch + leaderEpochIncrement, followerBrokerId, leaderBrokerId, countDownLatch,
-      expectTruncation = true, localLogOffset = Some(10))
+      expectTruncation = expectTruncation, localLogOffset = Some(10), extraProps = extraProps)
 
     // Initialize partition state to follower, with leader = 1, leaderEpoch = 1
     val tp = new TopicPartition(topic, topicPartition)
@@ -1528,7 +1542,9 @@ class ReplicaManagerTest {
               override def doWork() = {
                 // In case the thread starts before the partition is added by AbstractFetcherManager,
                 // add it here (it's a no-op if already added)
-                val initialOffset = OffsetAndEpoch(offset = 0L, leaderEpoch = leaderEpochInLeaderAndIsr)
+                val initialOffset = InitialFetchState(
+                  leader = new BrokerEndPoint(0, "localhost", 9092),
+                  initOffset = 0L, currentLeaderEpoch = leaderEpochInLeaderAndIsr)
                 addPartitions(Map(new TopicPartition(topic, topicPartition) -> initialOffset))
                 super.doWork()
 
diff --git a/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceWithIbp26Test.scala b/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceWithIbp26Test.scala
new file mode 100644
index 0000000..2ad4776
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceWithIbp26Test.scala
@@ -0,0 +1,29 @@
+/**
+  * Licensed to the Apache Software Foundation (ASF) under one or more
+  * contributor license agreements.  See the NOTICE file distributed with
+  * this work for additional information regarding copyright ownership.
+  * The ASF licenses this file to You under the Apache License, Version 2.0
+  * (the "License"); you may not use this file except in compliance with
+  * the License.  You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+package kafka.server.epoch
+
+import kafka.api.KAFKA_2_6_IV0
+
+/**
+ * From IBP 2.7 onwards, we truncate based on diverging epochs returned in fetch responses.
+ * EpochDrivenReplicationProtocolAcceptanceTest exercises epochs with the latest version. This test
+ * verifies that truncation on leader/follower change is handled correctly with older IBP versions.
+ */
+class EpochDrivenReplicationProtocolAcceptanceWithIbp26Test extends EpochDrivenReplicationProtocolAcceptanceTest {
+  override val apiVersion = KAFKA_2_6_IV0
+}
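
Running the whole acceptance suite under an older protocol only needs the apiVersion override above; an individual broker config can be pinned the same way, as kafkaConfigNoTruncateOnFetch does in ReplicaFetcherThreadTest. A minimal sketch:

    val props = TestUtils.createBrokerConfig(1, "localhost:1234")
    props.setProperty(KafkaConfig.InterBrokerProtocolVersionProp, KAFKA_2_6_IV0.version)
    val config = KafkaConfig.fromProps(props) // broker behaves as IBP 2.6, no truncation on fetch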
diff --git a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala b/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
index f8c528b..a587430 100644
--- a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
+++ b/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
@@ -31,6 +31,8 @@ import org.apache.kafka.common.requests.{AbstractRequest, EpochEndOffset, FetchR
 import org.apache.kafka.common.utils.{SystemTime, Time}
 import org.apache.kafka.common.{Node, TopicPartition}
 
+import scala.collection.Map
+
 /**
   * Stub network client used for testing the ReplicaFetcher, wraps the MockClient used for consumer testing
   *
@@ -53,17 +55,22 @@ class ReplicaFetcherMockBlockingSend(offsets: java.util.Map[TopicPartition, Epoc
   var epochFetchCount = 0
   var lastUsedOffsetForLeaderEpochVersion = -1
   var callback: Option[() => Unit] = None
-  var currentOffsets: java.util.Map[TopicPartition, EpochEndOffset] = offsets
+  var currentOffsets: util.Map[TopicPartition, EpochEndOffset] = offsets
+  var fetchPartitionData: Map[TopicPartition, FetchResponse.PartitionData[Records]] = Map.empty
   private val sourceNode = new Node(sourceBroker.id, sourceBroker.host, sourceBroker.port)
 
   def setEpochRequestCallback(postEpochFunction: () => Unit): Unit = {
     callback = Some(postEpochFunction)
   }
 
-  def setOffsetsForNextResponse(newOffsets: java.util.Map[TopicPartition, EpochEndOffset]): Unit = {
+  def setOffsetsForNextResponse(newOffsets: util.Map[TopicPartition, EpochEndOffset]): Unit = {
     currentOffsets = newOffsets
   }
 
+  def setFetchPartitionDataForNextResponse(partitionData: Map[TopicPartition, FetchResponse.PartitionData[Records]]): Unit = {
+    fetchPartitionData = partitionData
+  }
+
   override def sendRequest(requestBuilder: Builder[_ <: AbstractRequest]): ClientResponse = {
     if (!NetworkClientUtils.awaitReady(client, sourceNode, time, 500))
       throw new SocketTimeoutException(s"Failed to connect within 500 ms")
@@ -82,8 +89,11 @@ class ReplicaFetcherMockBlockingSend(offsets: java.util.Map[TopicPartition, Epoc
 
       case ApiKeys.FETCH =>
         fetchCount += 1
-        new FetchResponse(Errors.NONE, new java.util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData[Records]], 0,
-          JFetchMetadata.INVALID_SESSION_ID)
+        val partitionData = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData[Records]]
+        fetchPartitionData.foreach { case (tp, data) => partitionData.put(tp, data) }
+        fetchPartitionData = Map.empty
+        new FetchResponse(Errors.NONE, partitionData, 0,
+          if (partitionData.isEmpty) JFetchMetadata.INVALID_SESSION_ID else 1)
 
       case _ =>
         throw new UnsupportedOperationException
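
With the two additions above, a test can stage partition data for exactly one fetch: the mock returns it once, then reverts to empty responses. A usage sketch matching the ReplicaFetcherThreadTest code earlier in this diff:

    val mockNetwork = new ReplicaFetcherMockBlockingSend(
      Collections.emptyMap(), brokerEndPoint, new SystemTime())
    mockNetwork.setFetchPartitionDataForNextResponse(Map(
      t1p0 -> new FetchResponse.PartitionData[Records](
        Errors.NONE, 0, 0, 0, Optional.empty(), Collections.emptyList(),
        Optional.of(new FetchResponseData.EpochEndOffset().setEpoch(4).setEndOffset(140)),
        MemoryRecords.EMPTY)))
    // The next FETCH the thread sends receives this data; fetchPartitionData then resets.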
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
index 6eeca1b..02bf2f6 100644
--- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
@@ -31,6 +31,7 @@ import kafka.server.AlterIsrManager;
 import kafka.server.BrokerState;
 import kafka.server.BrokerTopicStats;
 import kafka.server.FailedPartitions;
+import kafka.server.InitialFetchState;
 import kafka.server.KafkaConfig;
 import kafka.server.LogDirFailureChannel;
 import kafka.server.MetadataCache;
@@ -137,7 +138,7 @@ public class ReplicaFetcherThreadBenchmark {
                 Time.SYSTEM);
 
         LinkedHashMap<TopicPartition, FetchResponse.PartitionData<BaseRecords>> initialFetched = new LinkedHashMap<>();
-        scala.collection.mutable.Map<TopicPartition, OffsetAndEpoch> offsetAndEpochs = new scala.collection.mutable.HashMap<>();
+        scala.collection.mutable.Map<TopicPartition, InitialFetchState> initialFetchStates = new scala.collection.mutable.HashMap<>();
         for (int i = 0; i < partitionCount; i++) {
             TopicPartition tp = new TopicPartition("topic", i);
 
@@ -162,7 +163,7 @@ public class ReplicaFetcherThreadBenchmark {
 
             partition.makeFollower(partitionState, offsetCheckpoints);
             pool.put(tp, partition);
-            offsetAndEpochs.put(tp, new OffsetAndEpoch(0, 0));
+            initialFetchStates.put(tp, new InitialFetchState(new BrokerEndPoint(3, "host", 3000), 0, 0));
             BaseRecords fetched = new BaseRecords() {
                 @Override
                 public int sizeInBytes() {
@@ -181,7 +182,7 @@ public class ReplicaFetcherThreadBenchmark {
         ReplicaManager replicaManager = Mockito.mock(ReplicaManager.class);
         Mockito.when(replicaManager.brokerTopicStats()).thenReturn(brokerTopicStats);
         fetcher = new ReplicaFetcherBenchThread(config, replicaManager, pool);
-        fetcher.addPartitions(offsetAndEpochs);
+        fetcher.addPartitions(initialFetchStates);
         // force a pass to move partitions to fetching state. We do this in the setup phase
         // so that we do not measure this time as part of the steady state work
         fetcher.doWork();
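
The benchmark change tracks the same API migration as the tests: OffsetAndEpoch gives way to InitialFetchState, which additionally names the leader endpoint. A minimal Scala sketch mirroring the named-argument form used in the test helpers above (fetcherThread and topicPartition are assumed to be in scope):

    val fetchState = InitialFetchState(
      leader = new BrokerEndPoint(0, "localhost", 9092),
      initOffset = 0L,
      currentLeaderEpoch = 1)
    fetcherThread.addPartitions(Map(topicPartition -> fetchState))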

