Spark Shuffle Read Path: Source Code Analysis

compute is the abstract interface through which an RDD computes a given partition. We take CoGroupedRDD's compute method as the concrete implementation to analyze (ShuffledRDD would work just as well: the compute details differ, but the classes and methods invoked for the shuffle read are the same). Source:

override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = {
  val split = s.asInstanceOf[CoGroupPartition]
  val numRdds = dependencies.length
  val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
  for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
    case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked =>
      val dependencyPartition = split.narrowDeps(depNum).get.split
      // Narrow dependency: read the parent RDD's partition directly
      val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
      rddIterators += ((it, depNum))
    case shuffleDependency: ShuffleDependency[_, _, _] =>
      // Shuffle dependency: get the ShuffleManager from SparkEnv, take the shuffleHandle that was
      // obtained when the shuffle was registered with the ShuffleManager, ask it for a ShuffleReader
      // covering this task's partition, and call the reader's read() to fetch the map-side shuffle output
      val it = SparkEnv.get.shuffleManager
        .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
        .read()
      rddIterators += ((it, depNum))
  }
  val map = createExternalMap(numRdds)
  for ((it, depNum) <- rddIterators) {
    map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
  }
  context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
  context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
  context.internalMetricsToAccumulators(
    InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
  new InterruptibleIterator(context,
    map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
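For comparison, ShuffledRDD's compute is essentially just the shuffle branch above. A sketch of what it looks like in this Spark line (details may differ slightly between releases):

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  // A ShuffledRDD has exactly one dependency: its ShuffleDependency
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  // Ask the ShuffleManager for a reader over this single partition and read it
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}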

As the source shows, the compute operation of an RDD with a wide (shuffle) dependency ultimately obtains a data reader via the getReader method of the ShuffleManager instance held by SparkEnv, and then calls the reader's read method to fetch the shuffle data for the specified partition range.
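In Spark 1.6, for example, SortShuffleManager's getReader simply wraps the handle and the partition range in a BlockStoreShuffleReader (a sketch; older releases returned a HashShuffleReader here instead):

override def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext): ShuffleReader[K, C] = {
  // All shuffle reads go through the block-store-based reader
  new BlockStoreShuffleReader(
    handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}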

The ShuffleReader trait is implemented by BlockStoreShuffleReader. The source of BlockStoreShuffleReader's read method:

/** Read the combined key-value pairs for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
  val blockFetcherItr = new ShuffleBlockFetcherIterator(
    context,
    blockManager.shuffleClient,
    blockManager,
    // Metadata that each ShuffleMapTask registered with mapOutputTracker when it finished;
    // it is retrieved here through mapOutputTracker, restricted to the requested partition range
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    // Defaults to 48 MB. Caps the data in flight: it keeps any single remote machine from using
    // too much bandwidth, while still allowing parallel fetches to speed up reading
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
  // Wrap each fetched block's input stream for decompression, keyed by its unique BlockId
  val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
    blockManager.wrapForCompression(blockId, inputStream)  // codec is one of lz4, lzf or snappy
  }
  val ser = Serializer.getSerializer(dep.serializer)
  val serializerInstance = ser.newInstance()
  // Create a key-value iterator for each wrapped stream
  val recordIter = wrappedStreams.flatMap { wrappedStream =>
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }
  // Update the task context's shuffle-read metrics as each record is consumed
  val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map(record => {
      readMetrics.incRecordsRead(1)
      record
    }),
    context.taskMetrics().updateShuffleReadMetrics())
  // An interruptible iterator must be used so the task can be cancelled
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
  // Aggregate the fetched data if an aggregator is defined
  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) { // The fetched data was already combined on the map side
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      // Merge the per-key combiners produced by each map-side partition; map-side combining
      // greatly reduces the amount of data transferred over the network
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else { // Values are combined only on the reduce side
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {  // No aggregation needed
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }

  // In the sort-based shuffle, map output is sorted only by partition id by default; records inside
  // a partition are not sorted. The keyOrdering field therefore indicates whether records within a
  // partition must also be sorted; if an ordering is defined, sort the output.
  dep.keyOrdering match {  // Check whether sorting is required
    case Some(keyOrd: Ordering[K]) =>
      // An external sorter is used to reduce memory pressure and avoid GC overhead. When memory
      // cannot hold all the data being sorted, the spark.shuffle.spill setting decides whether to
      // spill to disk; spilling is enabled by default.
      val sorter =
        new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
      sorter.insertAll(aggregatedIter)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.internalMetricsToAccumulators(
        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
      CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
    case None =>
      aggregatedIter // No intra-partition ordering required, so return the aggregated iterator directly
  }
}

BlockStoreShuffleReader's read method constructs a ShuffleBlockFetcherIterator, whose constructor invokes its initialize method; initialize calls splitLocalRemoteBlocks, fetchUpToMaxBytes and fetchLocalBlocks in turn, and we analyze them in that order, starting with splitLocalRemoteBlocks.
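For orientation, initialize itself is short; roughly the following in this Spark line (a sketch, with logging omitted):

private[this] def initialize(): Unit = {
  // Release fetched buffers whether the task succeeds or fails
  context.addTaskCompletionListener(_ => cleanup())
  // Split the blocks to read into local and remote ones
  val remoteRequests = splitLocalRemoteBlocks()
  // Queue the remote requests in random order to spread load across executors
  fetchRequests ++= Utils.randomize(remoteRequests)
  // Send out the initial requests, up to maxBytesInFlight
  fetchUpToMaxBytes()
  // Finally read the blocks that live on this executor
  fetchLocalBlocks()
}

The source of splitLocalRemoteBlocks: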

private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // Fetch from at most 5 nodes in parallel at any time, so each request asks for at most
  // maxBytesInFlight / 5 bytes (spark.reducer.maxSizeInFlight defaults to 48 MB)
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  val remoteRequests = new ArrayBuffer[FetchRequest]
  var totalBlocks = 0
  for ((address, blockInfos) <- blocksByAddress) {
    totalBlocks += blockInfos.size
    if (address.executorId == blockManager.blockManagerId.executorId) {
      // The blocks live on the same node as this BlockManager: filter out empty blocks and
      // record the rest in localBlocks
      localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
      numBlocksToFetch += localBlocks.size
    } else { // The blocks live on remote nodes
      val iterator = blockInfos.iterator
      var curRequestSize = 0L
      var curBlocks = new ArrayBuffer[(BlockId, Long)]
      while (iterator.hasNext) { // BlockId format: shuffle_<shuffleId>_<mapId>_<reduceId>
        val (blockId, size) = iterator.next()
        if (size > 0) { // Skip empty blocks
          curBlocks += ((blockId, size))
          remoteBlocks += blockId // Remember the ids of blocks held on remote machines
          numBlocksToFetch += 1
          curRequestSize += size
        } else if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        }
        if (curRequestSize >= targetRequestSize) {
          // Once the accumulated size reaches maxBytesInFlight / 5, emit a FetchRequest that
          // records the block location (address) together with the block ids and their sizes
          remoteRequests += new FetchRequest(address, curBlocks)
          curBlocks = new ArrayBuffer[(BlockId, Long)]
          curRequestSize = 0
        }
      }
      // Emit one final request for whatever blocks remain for this address
      if (curBlocks.nonEmpty) {
        remoteRequests += new FetchRequest(address, curBlocks)
      }
      // Note: a FetchRequest is never split, so if a single block is very large, fetching it can
      // use a great deal of memory and may cause an OOM
    }
  }
  remoteRequests
}
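As a quick illustration of the sizing (hypothetical values, using the default configuration):

// With the default spark.reducer.maxSizeInFlight = 48m:
val maxBytesInFlight = 48L * 1024 * 1024                    // 50331648 bytes
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)  // ~9.6 MB per FetchRequest
// so roughly five such requests can be outstanding at once without exceeding maxBytesInFlight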

ShuffleBlockFetcherIterator's fetchUpToMaxBytes method sends the requests that fetch remote data; a request is only sent while nothing is currently in flight, or while the bytes already in flight plus the size of the next queued request do not exceed maxBytesInFlight:

private def fetchUpToMaxBytes(): Unit = {
  while (fetchRequests.nonEmpty &&
    (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
    sendRequest(fetchRequests.dequeue())
  }
}
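sendRequest is what actually issues the network fetch: it adds the request's size to bytesInFlight and asks the shuffle client to fetch the blocks asynchronously, with a listener that pushes each result into the same results queue the iterator polls. A simplified sketch for this Spark line (the zombie check and metric updates are elided):

private[this] def sendRequest(req: FetchRequest) {
  bytesInFlight += req.size
  // Remember each block's declared size so the success callback can report it
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val blockIds = req.blocks.map(_._1.toString)
  val address = req.address
  // Asynchronous fetch; results land in the queue that next() consumes
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        buf.retain() // the buffer is handed to another thread, so bump its reference count
        results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
      }
      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        results.put(new FailureFetchResult(BlockId(blockId), address, e))
      }
    })
}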

After fetchUpToMaxBytes has issued the remote fetches, fetchLocalBlocks reads the blocks that live on the local executor. The source of ShuffleBlockFetcherIterator's fetchLocalBlocks:

private[this] def fetchLocalBlocks() {
  val iter = localBlocks.iterator
  while (iter.hasNext) {
    val blockId = iter.next()
    try {
      // Local blocks go straight through the BlockManager; no network transfer is involved
      val buf = blockManager.getBlockData(blockId)
      shuffleMetrics.incLocalBlocksFetched(1)
      shuffleMetrics.incLocalBytesRead(buf.size)
      buf.retain() // keep the buffer alive until the iterator consumes it
      results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
    } catch {
      case e: Exception =>
        results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
        return
    }
  }
}

fetchLocalBlocks obtains local blocks by calling BlockManager's getBlockData method, which for shuffle blocks delegates to the getBlockData of IndexShuffleBlockResolver or FileShuffleBlockResolver (both extend the ShuffleBlockResolver trait):

The call chain is: ShuffleBlockFetcherIterator.fetchLocalBlocks -> BlockManager.getBlockData -> ShuffleBlockResolver.getBlockData.
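The dispatch inside BlockManager looks roughly like this (a sketch for this Spark line; the branch for non-shuffle blocks is elided):

override def getBlockData(blockId: BlockId): ManagedBuffer = {
  if (blockId.isShuffle) {
    // Shuffle blocks are resolved by the shuffle manager's ShuffleBlockResolver
    shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
  } else {
    // Non-shuffle blocks are read from the local block store (omitted here)
    ...
  }
}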

IndexShuffleBlockResolver's getBlockData locates the per-map data and index files through DiskBlockManager's getFile method, while FileShuffleBlockResolver's getBlockData (used by the hash-based shuffle) wraps the corresponding file segment directly in a FileSegmentManagedBuffer.
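For the sort-based shuffle, IndexShuffleBlockResolver.getBlockData reads two longs from the index file to find this reduce partition's byte range inside the map task's single data file, and wraps that range in a FileSegmentManagedBuffer. A sketch for this Spark line:

override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
  // The reduce partition is a contiguous segment of the map task's single output file;
  // index entries reduceId and reduceId + 1 delimit that segment
  val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
  val in = new DataInputStream(new FileInputStream(indexFile))
  try {
    ByteStreams.skipFully(in, blockId.reduceId * 8)
    val offset = in.readLong()
    val nextOffset = in.readLong()
    new FileSegmentManagedBuffer(
      transportConf,
      getDataFile(blockId.shuffleId, blockId.mapId),
      offset,
      nextOffset - offset)
  } finally {
    in.close()
  }
}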
