DirectKafkaInputDStream Source Code Analysis (Including Dynamic Partition Detection)

First, an example from the official site:

object DirectKafkaWordCount {
  def main(args: Array[String]) {
    if (args.length < 2) {
      System.err.println(s"""
        |Usage: DirectKafkaWordCount <brokers> <topics>
        |  <brokers> is a list of one or more Kafka brokers
        |  <topics> is a list of one or more kafka topics to consume from
        |
        """.stripMargin)
      System.exit(1)
    }

    StreamingExamples.setStreamingLogLevels()

    val Array(brokers, topics) = args

    // Create context with 2 second batch interval
    val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount")
    val ssc = new StreamingContext(sparkConf, Seconds(2))

    // Create direct kafka stream with brokers and topics
    val topicsSet = topics.split(",").toSet
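    // The kafka010 consumer needs bootstrap.servers and deserializers rather than the
    // 0.8-era "metadata.broker.list". The group.id below is only a placeholder.
    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> brokers,
      "group.id" -> "direct-kafka-word-count",
      "key.deserializer" -> "org.apache.kafka.common.serialization.StringDeserializer",
      "value.deserializer" -> "org.apache.kafka.common.serialization.StringDeserializer")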
    val messages = KafkaUtils.createDirectStream[String, String](
      ssc,
      LocationStrategies.PreferConsistent,
      ConsumerStrategies.Subscribe[String, String](topicsSet, kafkaParams))

    // Get the lines, split them into words, count the words and print
    val lines = messages.map(_.value)
    val words = lines.flatMap(_.split(" "))
    val wordCounts = words.map(x => (x, 1L)).reduceByKey(_ + _)
    wordCounts.print()

    // Start the computation
    ssc.start()
    ssc.awaitTermination()
  }
}

The Kafka-backed input stream is created through KafkaUtils.createDirectStream, and that is the code this article analyzes:

Open the Spark 2.3.0 source, go to spark-2.3.0\external, and in the org.apache.spark.streaming.kafka010 package open KafkaUtils and find KafkaUtils.createDirectStream:

def createDirectStream[K, V](
      ssc: StreamingContext,
      locationStrategy: LocationStrategy,
      consumerStrategy: ConsumerStrategy[K, V]
    ): InputDStream[ConsumerRecord[K, V]] = {
    val ppc = new DefaultPerPartitionConfig(ssc.sparkContext.getConf)
    createDirectStream[K, V](ssc, locationStrategy, consumerStrategy, ppc)
  }

We pass in three arguments. ssc is the streaming context. What is LocationStrategy? Look at this code:

object LocationStrategies {
  /**
   *  :: Experimental ::
   * Use this only if your executors are on the same nodes as your Kafka brokers.
   */
  @Experimental
  def PreferBrokers: LocationStrategy =
    org.apache.spark.streaming.kafka010.PreferBrokers

  /**
   *  :: Experimental ::
   * Use this in most cases, it will consistently distribute partitions across all executors.
   */
  @Experimental
  def PreferConsistent: LocationStrategy =
    org.apache.spark.streaming.kafka010.PreferConsistent

  /**
   *  :: Experimental ::
   * Use this to place particular TopicPartitions on particular hosts if your load is uneven.
   * Any TopicPartition not specified in the map will use a consistent location.
   */
  @Experimental
  def PreferFixed(hostMap: collection.Map[TopicPartition, String]): LocationStrategy =
    new PreferFixed(new ju.HashMap[TopicPartition, String](hostMap.asJava))

  /**
   *  :: Experimental ::
   * Use this to place particular TopicPartitions on particular hosts if your load is uneven.
   * Any TopicPartition not specified in the map will use a consistent location.
   */
  @Experimental
  def PreferFixed(hostMap: ju.Map[TopicPartition, String]): LocationStrategy =
    new PreferFixed(hostMap)
}

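LocationStrategy comes in three flavors. The first is PreferBrokers. Its source comment reads 'Use this only if your executors are on the same nodes as your Kafka brokers', meaning it applies only when the executors run on the same nodes as the Kafka brokers, which is rare in practice.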

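The second, PreferConsistent, is the one used most often: it distributes partitions consistently across all executors. The third, PreferFixed, pins particular TopicPartitions to particular hosts when load is uneven; any partition missing from the map falls back to a consistent location.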
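To make "consistent" concrete, here is a minimal sketch of one way such a placement could work. It is not Spark's code: TopicPartition is a simplified stand-in for the Kafka class, chooseHost is a hypothetical helper, and the hash-modulo rule over a sorted host list is an assumption made for illustration.

```scala
// Sketch only: simplified stand-ins, not Spark's or Kafka's classes.
object ConsistentPlacementSketch {
  final case class TopicPartition(topic: String, partition: Int)

  // Pick a host for a partition. A pinned host (PreferFixed style) wins if it is
  // available; otherwise fall back to a deterministic hash-based choice.
  def chooseHost(
      tp: TopicPartition,
      hosts: Seq[String],
      fixed: Map[TopicPartition, String] = Map.empty): Option[String] = {
    val pinned = fixed.get(tp).filter(h => hosts.contains(h))
    pinned.orElse {
      if (hosts.isEmpty) None
      else {
        // Sorting makes the result independent of the order hosts were listed in.
        val sorted = hosts.sorted
        // floorMod keeps the index non-negative even for negative hash codes.
        Some(sorted(Math.floorMod(tp.hashCode, sorted.length)))
      }
    }
  }
}
```

The point of the design is that the same partition lands on the same host batch after batch, which lets a cached consumer on that host be reused.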

ConsumerStrategies.Subscribe wraps the list of topics to subscribe to together with the Kafka parameters.

createDirectStream also instantiates DefaultPerPartitionConfig(ssc.sparkContext.getConf). Its purpose is simple: it rate-limits reads from Kafka.

private class DefaultPerPartitionConfig(conf: SparkConf)
    extends PerPartitionConfig {
  val maxRate = conf.getLong("spark.streaming.kafka.maxRatePerPartition", 0)

  def maxRatePerPartition(topicPartition: TopicPartition): Long = maxRate
}

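The spark.streaming.kafka.maxRatePerPartition setting caps how many records per second are read from each partition, which keeps an overloaded machine from causing problems with the data transfer from Kafka. With the default of 0, no per-partition cap is applied.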
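As a rough illustration of what the setting means per batch, the sketch below converts a records-per-second rate into a per-batch message budget. maxMessagesPerBatch is a hypothetical helper, and the sketch ignores the backpressure rate estimator that Spark can also take into account.

```scala
// Sketch only: turns a per-second rate into a per-batch message budget.
object RateLimitSketch {
  // maxRatePerPartition is in records per second; 0 or less is treated as "no limit".
  def maxMessagesPerBatch(maxRatePerPartition: Long, batchMillis: Long): Option[Long] =
    if (maxRatePerPartition <= 0) None
    else Some((maxRatePerPartition * batchMillis / 1000.0).toLong)
}
```

With the 2 second batch interval from the example and a rate of 1000, each partition would contribute at most 2000 records per batch.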

Next, this overload is called:

def createDirectStream[K, V](
      ssc: StreamingContext,
      locationStrategy: LocationStrategy,
      consumerStrategy: ConsumerStrategy[K, V],
      perPartitionConfig: PerPartitionConfig
    ): InputDStream[ConsumerRecord[K, V]] = {
    new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy, perPartitionConfig)
  }
  

It creates a new DirectKafkaInputDStream. Let's look at its source:

private[spark] class DirectKafkaInputDStream[K, V](
    _ssc: StreamingContext,
    locationStrategy: LocationStrategy,
    consumerStrategy: ConsumerStrategy[K, V],
    ppc: PerPartitionConfig
  ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets 
  
  

abstract class InputDStream[T: ClassTag](_ssc: StreamingContext)
  extends DStream[T](_ssc)  
  
  
  
abstract class DStream[T: ClassTag] (
@transient private[streaming] var ssc: StreamingContext
) extends Serializable with Logging {

validateAtInit()

// =======================================================================
// Methods that should be implemented by subclasses of DStream
// =======================================================================

/** Time interval after which the DStream generates an RDD */
def slideDuration: Duration

/** List of parent DStreams on which this DStream depends on */
def dependencies: List[DStream[_]]

/** Method that generates an RDD for the given time */
def compute(validTime: Time): Option[RDD[T]]

// =======================================================================
// Methods and fields available on all DStreams
// =======================================================================

// RDDs generated, marked as private[streaming] so that testsuites can access it
@transient
private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]]()
...

DirectKafkaInputDStream extends InputDStream, which in turn extends DStream. The comment on compute makes its role clear: it generates the RDD for a given time.

DirectKafkaInputDStream produces its KafkaRDD through its own override of compute:

 override def compute(validTime: Time): Option[KafkaRDD[K, V]] = {
    val untilOffsets = clamp(latestOffsets())
    val offsetRanges = untilOffsets.map { case (tp, uo) =>
      val fo = currentOffsets(tp)
      OffsetRange(tp.topic, tp.partition, fo, uo)
    }
    val useConsumerCache = context.conf.getBoolean("spark.streaming.kafka.consumer.cache.enabled",
      true)
    val rdd = new KafkaRDD[K, V](context.sparkContext, executorKafkaParams, offsetRanges.toArray,
      getPreferredHosts, useConsumerCache)

    // Report the record number and metadata of this batch interval to InputInfoTracker.
    val description = offsetRanges.filter { offsetRange =>
      // Don't display empty ranges.
      offsetRange.fromOffset != offsetRange.untilOffset
    }.map { offsetRange =>
      s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" +
        s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}"
    }.mkString("\n")
    // Copy offsetRanges to immutable.List to prevent from being modified by the user
    val metadata = Map(
      "offsets" -> offsetRanges.toList,
      StreamInputInfo.METADATA_KEY_DESCRIPTION -> description)
    val inputInfo = StreamInputInfo(id, rdd.count, metadata)
    ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)

    currentOffsets = untilOffsets
    commitAll()
    Some(rdd)
  }

The time passed in has already been validated. The check looks like this:

/**
   * Checks whether the 'time' is valid wrt slideDuration for generating RDD.
   * Additionally it also ensures valid times are in strictly increasing order.
   * This ensures that InputDStream.compute() is called strictly on increasing
   * times.
   */
  override private[streaming] def isTimeValid(time: Time): Boolean = {
    if (!super.isTimeValid(time)) {
      false // Time not valid
    } else {
      // Time is valid, but check it is more than lastValidTime
      if (lastValidTime != null && time < lastValidTime) {
        logWarning(s"isTimeValid called with $time whereas the last valid time " +
          s"is $lastValidTime")
      }
      lastValidTime = time
      true
    }
  }
  
  
  /** Checks whether the 'time' is valid wrt slideDuration for generating RDD */
  private[streaming] def isTimeValid(time: Time): Boolean = {
    if (!isInitialized) {
      throw new SparkException (this + " has not been initialized")
    } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) {
      logInfo(s"Time $time is invalid as zeroTime is $zeroTime" +
        s" , slideDuration is $slideDuration and difference is ${time - zeroTime}")
      false
    } else {
      logDebug(s"Time $time is valid")
      true
    }
  }
  
  
  def isMultipleOf(that: Duration): Boolean =
    (this.millis % that.millis == 0)

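zeroTime is the time at which the streaming computation started. A time is valid when it is later than zeroTime and time - zeroTime is an exact multiple of slideDuration.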
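The rule can be checked with plain numbers. The sketch below restates it over millisecond values; it is a simplification that leaves out the initialization check and the logging in the real method.

```scala
// Sketch only: the validity rule over plain millisecond timestamps.
object TimeValiditySketch {
  def isTimeValid(timeMs: Long, zeroTimeMs: Long, slideMs: Long): Boolean =
    timeMs > zeroTimeMs && (timeMs - zeroTimeMs) % slideMs == 0
}
```

For a stream that started at 1000 ms with a 2000 ms slide, 5000 ms is valid (the difference is 4000), while 4000 ms is not (the difference is 3000).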

Back to compute.

Its first line is val untilOffsets = clamp(latestOffsets()). Let's look at latestOffsets() first:

 /**
   * Returns the latest (highest) available offsets, taking new partitions into account.
   */
  protected def latestOffsets(): Map[TopicPartition, Long] = {
    val c = consumer
    paranoidPoll(c)
    val parts = c.assignment().asScala

    // make sure new partitions are reflected in currentOffsets
    val newPartitions = parts.diff(currentOffsets.keySet)
    // position for new partitions determined by auto.offset.reset if no commit
    currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap
    // don't want to consume messages, so pause
    c.pause(newPartitions.asJava)
    // find latest available offsets
    c.seekToEnd(currentOffsets.keySet.asJava)
    parts.map(tp => tp -> c.position(tp)).toMap
  }

*******************************************************
protected var currentOffsets = Map[TopicPartition, Long]()

def consumer(): Consumer[K, V] = this.synchronized {
    if (null == kc) {
      kc = consumerStrategy.onStart(currentOffsets.mapValues(l => new java.lang.Long(l)).asJava)
    }
    kc
  }
  
********************************************* 
def onStart(currentOffsets: ju.Map[TopicPartition, jl.Long]): Consumer[K, V] = {
    val consumer = new KafkaConsumer[K, V](kafkaParams)
    consumer.subscribe(topics)
    val toSeek = if (currentOffsets.isEmpty) {
      offsets
    } else {
      currentOffsets
    }
    if (!toSeek.isEmpty) {
      // work around KAFKA-3370 when reset is none
      // poll will throw if no position, i.e. auto offset reset none and no explicit position
      // but cant seek to a position before poll, because poll is what gets subscription partitions
      // So, poll, suppress the first exception, then seek
      val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)
      val shouldSuppress =
        aor != null && aor.asInstanceOf[String].toUpperCase(Locale.ROOT) == "NONE"
      try {
        consumer.poll(0)
      } catch {
        case x: NoOffsetForPartitionException if shouldSuppress =>
          logWarning("Catching NoOffsetForPartitionException since " +
            ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + " is none.  See KAFKA-3370")
      }
      toSeek.asScala.foreach { case (topicPartition, offset) =>
          consumer.seek(topicPartition, offset)
      }
      // we've called poll, we must pause or next poll may consume messages and set position
      consumer.pause(consumer.assignment())
    }

    consumer
  }  

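The first thing latestOffsets does is call consumer, which lazily creates the Kafka consumer on first use. Because the example passes ConsumerStrategies.Subscribe[String, String](topicsSet, kafkaParams), it is Subscribe's onStart that runs. If no starting offsets are supplied, then, as the source comment puts it, 'the committed offset (if applicable) or kafka param auto.offset.reset will be used', so the starting position depends on your configuration.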

In effect, consumer() hands back a consumer that is already positioned at the offsets to start from.

Next comes paranoidPoll(c). It polls so that the consumer's partition assignment is refreshed, not in order to fetch data:

/**
   * The concern here is that poll might consume messages despite being paused,
   * which would throw off consumer position.  Fix position if this happens.
   */
  private def paranoidPoll(c: Consumer[K, V]): Unit = {
    val msgs = c.poll(0)
    if (!msgs.isEmpty) {
      // position should be minimum offset per topicpartition
      msgs.asScala.foldLeft(Map[TopicPartition, Long]()) { (acc, m) =>
        val tp = new TopicPartition(m.topic, m.partition)
        val off = acc.get(tp).map(o => Math.min(o, m.offset)).getOrElse(m.offset)
        acc + (tp -> off)
      }.foreach { case (tp, off) =>
          logInfo(s"poll(0) returned messages, seeking $tp to $off to compensate")
          c.seek(tp, off)
      }
    }
  }

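It calls Consumer.poll, and if any messages come back despite the pause, it seeks each partition back to the smallest returned offset (c.seek(tp, off)) so that the consumer position is not thrown off.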
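The fold in the middle of paranoidPoll is easy to try in isolation. The sketch below reproduces it with simplified stand-in types (TopicPartition and Record are not the Kafka classes), so it can run without a broker.

```scala
// Sketch only: the "minimum offset per partition" fold with stand-in types.
object ParanoidPollSketch {
  final case class TopicPartition(topic: String, partition: Int)
  final case class Record(topic: String, partition: Int, offset: Long)

  // For every partition seen in the batch, keep the smallest offset returned.
  def minOffsets(records: Seq[Record]): Map[TopicPartition, Long] =
    records.foldLeft(Map.empty[TopicPartition, Long]) { (acc, r) =>
      val tp = TopicPartition(r.topic, r.partition)
      val off = acc.get(tp).map(o => Math.min(o, r.offset)).getOrElse(r.offset)
      acc + (tp -> off)
    }
}
```

Seeking back to the smallest offset per partition guarantees that none of the accidentally polled messages is skipped.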

Next is a key feature of the kafka010 integration, the code that detects partitions dynamically:

val parts = c.assignment().asScala

    // make sure new partitions are reflected in currentOffsets
    val newPartitions = parts.diff(currentOffsets.keySet)
    // position for new partitions determined by auto.offset.reset if no commit
    currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap
    // don't want to consume messages, so pause
    c.pause(newPartitions.asJava)
    // find latest available offsets
    c.seekToEnd(currentOffsets.keySet.asJava)
    parts.map(tp => tp -> c.position(tp)).toMap

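It takes the consumer's currently assigned partitions and compares them with currentOffsets.keySet (currentOffsets maps each TopicPartition to its offset). Any partitions not yet tracked become newPartitions and are added to currentOffsets at their current position, which is how newly added partitions are picked up dynamically. The consumer then seeks to the end of every partition, and the resulting positions are returned as the latest available offsets.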
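The set difference at the heart of this step can be sketched without a consumer. Below, withNewPartitions is a hypothetical helper, TopicPartition is a stand-in type, and the position function stands in for c.position(tp).

```scala
// Sketch only: detecting new partitions with a set difference.
object NewPartitionSketch {
  final case class TopicPartition(topic: String, partition: Int)

  def withNewPartitions(
      currentOffsets: Map[TopicPartition, Long],
      assigned: Set[TopicPartition],
      position: TopicPartition => Long): Map[TopicPartition, Long] = {
    // Partitions the consumer now owns that we are not tracking yet.
    val newPartitions = assigned.diff(currentOffsets.keySet)
    // Start tracking each new partition from its current position.
    currentOffsets ++ newPartitions.map(tp => tp -> position(tp))
  }
}
```

Offsets of partitions that were already tracked stay untouched; only the newly discovered ones are added.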

Next is the clamp method:

// limits the maximum number of messages per partition
  protected def clamp(
    offsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {

    maxMessagesPerPartition(offsets).map { mmp =>
      mmp.map { case (tp, messages) =>
          val uo = offsets(tp)
          tp -> Math.min(currentOffsets(tp) + messages, uo)
      }
    }.getOrElse(offsets)
  }

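This simply caps each partition's end offset: a partition advances by at most its per-batch message limit beyond currentOffsets, and never past the latest available offset.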
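A small standalone version shows the effect with concrete numbers. This sketch passes currentOffsets and the per-partition message budget in as arguments, whereas the real method reads them from the stream's state; TopicPartition is again a stand-in type.

```scala
// Sketch only: clamp with its inputs passed explicitly.
object ClampSketch {
  final case class TopicPartition(topic: String, partition: Int)

  def clamp(
      currentOffsets: Map[TopicPartition, Long],
      latestOffsets: Map[TopicPartition, Long],
      maxMessages: Option[Map[TopicPartition, Long]]): Map[TopicPartition, Long] =
    maxMessages.map { mmp =>
      mmp.map { case (tp, messages) =>
        // Advance by at most `messages`, but never beyond what Kafka actually has.
        tp -> Math.min(currentOffsets(tp) + messages, latestOffsets(tp))
      }
    }.getOrElse(latestOffsets) // no limit configured: read up to the latest offsets
}
```

With a budget of 2000, a partition at offset 100 whose latest offset is 5000 is read up to 2100, while a partition at 50 whose latest offset is 60 is read up to 60.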

Then:

val offsetRanges = untilOffsets.map { case (tp, uo) =>
      val fo = currentOffsets(tp)
      OffsetRange(tp.topic, tp.partition, fo, uo)
    }
    
val rdd = new KafkaRDD[K, V](context.sparkContext, executorKafkaParams, offsetRanges.toArray,
      getPreferredHosts, useConsumerCache)    

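When implementing the exactly-once pattern from the official documentation, we rely on offsetRanges. Each entry holds a topic, a partition, and the from and until offsets for this batch. The ranges are then passed to new KafkaRDD to build the RDD.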

// Report the record number and metadata of this batch interval to InputInfoTracker.
    val description = offsetRanges.filter { offsetRange =>
      // Don't display empty ranges.
      offsetRange.fromOffset != offsetRange.untilOffset
    }.map { offsetRange =>
      s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" +
        s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}"
    }.mkString("\n")
    // Copy offsetRanges to immutable.List to prevent from being modified by the user
    val metadata = Map(
      "offsets" -> offsetRanges.toList,
      StreamInputInfo.METADATA_KEY_DESCRIPTION -> description)
    val inputInfo = StreamInputInfo(id, rdd.count, metadata)
    ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)

This works somewhat like a log: it reports the record count and the offset metadata of this batch to InputInfoTracker.
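To see what ends up in that report, the sketch below builds offset ranges from two offset maps, derives the record count from the offsets alone, and formats the same description string. SimpleRange and the helper names are hypothetical stand-ins for OffsetRange and the inline code in compute.

```scala
// Sketch only: offset ranges, their record count, and the description string.
object OffsetRangeSketch {
  final case class TopicPartition(topic: String, partition: Int)
  final case class SimpleRange(topic: String, partition: Int, fromOffset: Long, untilOffset: Long) {
    // Number of records in the range, known without reading any data.
    def count: Long = untilOffset - fromOffset
  }

  def ranges(
      current: Map[TopicPartition, Long],
      until: Map[TopicPartition, Long]): Seq[SimpleRange] =
    until.toSeq
      .sortBy { case (tp, _) => (tp.topic, tp.partition) }
      .map { case (tp, uo) => SimpleRange(tp.topic, tp.partition, current(tp), uo) }

  // Same format as the description above; empty ranges are left out.
  def describe(rs: Seq[SimpleRange]): String =
    rs.filter(r => r.fromOffset != r.untilOffset)
      .map(r => s"topic: ${r.topic}\tpartition: ${r.partition}\toffsets: ${r.fromOffset} to ${r.untilOffset}")
      .mkString("\n")
}
```

Because each range already carries its start and end, the batch size is simply the sum of the per-range counts.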

currentOffsets = untilOffsets
commitAll()
Some(rdd)


protected def commitAll(): Unit = {
    val m = new ju.HashMap[TopicPartition, OffsetAndMetadata]()
    var osr = commitQueue.poll()
    while (null != osr) {
      val tp = osr.topicPartition
      val x = m.get(tp)
      val offset = if (null == x) { osr.untilOffset } else { Math.max(x.offset, osr.untilOffset) }
      m.put(tp, new OffsetAndMetadata(offset))
      osr = commitQueue.poll()
    }
    if (!m.isEmpty) {
      consumer.commitAsync(m, commitCallback.get)
    }
  }

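Finally, currentOffsets is advanced to untilOffsets, commitAll() commits whatever offset ranges have been queued in commitQueue (keeping the highest untilOffset per partition), and compute returns the RDD for this batch wrapped in Some.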
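The draining loop in commitAll can be tried on its own. The sketch below uses stand-in types and returns the offsets that would be committed instead of calling commitAsync; drain is a hypothetical name.

```scala
import java.util.concurrent.ConcurrentLinkedQueue

// Sketch only: drain a queue of ranges, keeping the highest untilOffset per partition.
object CommitAllSketch {
  final case class TopicPartition(topic: String, partition: Int)
  final case class SimpleRange(tp: TopicPartition, untilOffset: Long)

  def drain(queue: ConcurrentLinkedQueue[SimpleRange]): Map[TopicPartition, Long] = {
    var m = Map.empty[TopicPartition, Long]
    var osr = queue.poll() // poll returns null once the queue is empty
    while (osr != null) {
      val offset = m.get(osr.tp).map(x => Math.max(x, osr.untilOffset)).getOrElse(osr.untilOffset)
      m = m + (osr.tp -> offset)
      osr = queue.poll()
    }
    m
  }
}
```

Taking the maximum means that ranges queued out of order still result in a single commit at the furthest offset for each partition.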

End of analysis.
