Spark-ShuffleReader

1. Context

In "Spark-ShuffleManager" we sketched the overall outline of the ShuffleManager. In this article we take a detailed look at the ShuffleReader.

2. When does shuffle reading happen?

When the ShuffleWriter is about to write, the last RDD of the stage starts producing its data by calling its own iterator method.
If that RDD is neither persisted nor checkpointed, iterator falls through to compute, which in turn calls the parent RDD's iterator. Since all RDDs within a stage are connected by narrow dependencies, this iterator pipeline walks back until it reaches the first RDD of the stage, which is a ShuffledRDD, and calls its compute method.
This is the boundary of the stage: the ShuffledRDD has to fetch the previous stage's output and wrap it into an iterator for the downstream RDDs to consume. Let's look at ShuffledRDD's compute method:

class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
    @transient var prev: RDD[_ <: Product2[K, V]],
    part: Partitioner)
  extends RDD[(K, C)](prev.context, Nil) {

  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    val metrics = context.taskMetrics().createTempShuffleReadMetrics()
    //SparkEnv holds the single ShuffleManager instance; ask it for a reader over this partition range
    SparkEnv.get.shuffleManager.getReader(
      dep.shuffleHandle, split.index, split.index + 1, context, metrics)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

}
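
A ShuffledRDD sits at the head of a stage whenever a shuffle operator is used. Below is a minimal standalone sketch (illustrative app and object names, not taken from the article) showing when the read path above gets triggered:

import org.apache.spark.{SparkConf, SparkContext}

object ShuffledRddDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("demo"))
    // Map stage: narrow dependencies only, written out by the ShuffleMapTask
    val pairs = sc.parallelize(1 to 100).map(x => (x % 10, 1))
    // reduceByKey introduces a ShuffleDependency; the result stage starts with a ShuffledRDD
    val counts = pairs.reduceByKey(_ + _)
    // When this action runs, ShuffledRDD.compute above calls shuffleManager.getReader(...).read()
    counts.collect().foreach(println)
    sc.stop()
  }
}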

3. Getting the ShuffleReader

private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {

  override def getReader[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]]
    val (blocksByAddress, canEnableBatchFetch) =
      //This branch is taken when push-based shuffle is enabled (and the RDD is not part of a barrier stage)
      if (baseShuffleHandle.dependency.shuffleMergeEnabled) {
        //Returns a MapSizesByExecutorId, which is essentially a pair of
        //(iter, enableBatchFetch)
        //iter is an Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
        //BlockManagerId: unique identifier of a BlockManager (executorId, host, port)
        //BlockId : identifies a shuffle block
        //Long : shuffle block size
        //Int : map index
        val res = SparkEnv.get.mapOutputTracker.getPushBasedShuffleMapSizesByExecutorId(
          handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
        //whether batch fetch can be enabled
        (res.iter, res.enableBatchFetch)
      } else {
        //Ask the MapOutputTracker which nodes of the previous stage hold the map output we need to fetch
        val address = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
          handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
        (address, true)
      }
    //Fetches and reads the shuffle blocks by requesting them from other nodes' block stores.
    new BlockStoreShuffleReader(
      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
      shouldBatchFetch =
        canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context))
  }

}
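
Note that ShuffledRDD.compute above passed only five arguments, while this override takes seven. The ShuffleManager trait bridges the two with a convenience overload that reads the entire map-index range; roughly the following (a sketch of the Spark 3.x trait written from memory, not verbatim):

// Sketch of the convenience overload on the ShuffleManager trait: read all map outputs
// (map indices 0 until Int.MaxValue) for the given range of reduce partitions.
final def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext,
    metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
  getReader(handle, 0, Int.MaxValue, startPartition, endPartition, context, metrics)
}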

4. Fetching data with the ShuffleReader

1. BlockStoreShuffleReader

private[spark] class BlockStoreShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
    context: TaskContext,
    readMetrics: ShuffleReadMetricsReporter,
    serializerManager: SerializerManager = SparkEnv.get.serializerManager,
    blockManager: BlockManager = SparkEnv.get.blockManager,
    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
    shouldBatchFetch: Boolean = false)
  extends ShuffleReader[K, C] with Logging {

  override def read(): Iterator[Product2[K, C]] = {
    val wrappedStreams = new ShuffleBlockFetcherIterator(
      context,
      blockManager.blockStoreClient,
      blockManager,
      mapOutputTracker,
      blocksByAddress,
      serializerManager.wrapStream,
      // spark.reducer.maxSizeInFlight, default 48 (MiB)
      //Maximum size of map output each reduce task fetches at the same time, in MiB.
      //Since each output requires a buffer to receive it, this is a fixed memory overhead per
      //reduce task, so keep it small unless you have a large amount of memory
      SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
      //spark.reducer.maxReqsInFlight, default Int.MaxValue
      //Limits the number of remote requests fetching blocks at any given point. As the number of hosts in the cluster grows, a node may receive a large number of inbound connections, causing workers to fail under load. Limiting the number of fetch requests mitigates this
      SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
      //spark.reducer.maxBlocksInFlightPerAddress, default Int.MaxValue
      //Limits the number of remote blocks a reduce task fetches from a given host:port. Requesting a huge number of blocks from one address at once, or concurrently, can crash the serving executor or Node Manager. This is especially useful for reducing the load on the Node Manager when the external shuffle service is enabled; you can mitigate the issue by setting it to a lower value.
      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
      //spark.network.maxRemoteBlockSizeFetchToMem, default 200m
      //Remote blocks larger than this threshold (in bytes) are fetched to disk, so that a single huge request does not occupy too much memory. Note that this config affects both shuffle fetch and BlockManager remote block fetch. For users of the external shuffle service, this only works when the service is at least version 2.3.0.
      SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
      //spark.shuffle.maxAttemptsOnNettyOOM, default 10
      //Maximum number of times a shuffle block fetch is retried on a Netty OOM before a shuffle fetch failure is thrown
      SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
      //spark.shuffle.detectCorrupt, default true
      //Whether to detect corruption in the fetched blocks
      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
      //spark.shuffle.detectCorrupt.useExtraMemory, default false
      //If enabled, part of a compressed/encrypted stream is decompressed/decrypted using extra memory to detect corruption early. Any IOException thrown causes the task to be retried once; if it fails again with the same exception, a FetchFailedException is thrown so that the previous stage can be retried
      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
      //spark.shuffle.checksum.enabled, default true
      //Whether to compute checksums for shuffle data. If enabled, Spark computes a checksum for each partition of the map output file and stores the values in a checksum file on disk. When shuffle data corruption is detected, Spark uses the checksum file to diagnose its cause (e.g. network issue, disk issue).
      SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
      //spark.shuffle.checksum.algorithm, default ADLER32
      //Algorithm used to compute the shuffle checksums. Currently only the JDK's built-in algorithms are supported
      SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
      readMetrics,
      //whether continuous blocks can be fetched in a single batch
      fetchContinuousBlocksInBatch).toCompletionIterator

    val serializerInstance = dep.serializer.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the task metrics in the context for every record read.
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // An interruptible iterator must be used here to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    //Does the dependency define an aggregator?
    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      //Was the aggregation already done on the map side? If not, it is done here on the reduce side
      if (dep.mapSideCombine) {
        // We are reading values that were already combined on the map side
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, and we don't care: the dependency *should* have ensured it is
        // compatible with this aggregator, which converts the value type into the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if a key ordering is defined.
    val resultIter: Iterator[Product2[K, C]] = dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAllAndUpdateMetrics(aggregatedIter)
      case None =>
        aggregatedIter
    }

    resultIter match {
      case _: InterruptibleIterator[Product2[K, C]] => resultIter
      case _ =>
        // Use another interruptible iterator here to support task cancellation, because the aggregator
        // and/or the sorter may have already fully consumed the previous interruptible iterator.
        new InterruptibleIterator[Product2[K, C]](context, resultIter)
    }
  }


}
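
All the limits fed into ShuffleBlockFetcherIterator above are ordinary Spark configs, so they can be tuned per application. A hedged sketch of overriding a few of them, e.g. pasted into spark-shell (values are purely illustrative, not recommendations):

import org.apache.spark.SparkConf

// Illustrative values only; the defaults are usually fine.
val conf = new SparkConf()
  .set("spark.reducer.maxSizeInFlight", "96m")              // default 48m
  .set("spark.reducer.maxBlocksInFlightPerAddress", "128")  // default Int.MaxValue
  .set("spark.network.maxRemoteBlockSizeFetchToMem", "200m")
  .set("spark.shuffle.detectCorrupt", "true")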

2. ShuffleBlockFetcherIterator

An iterator that fetches multiple blocks. Local blocks are read from the local BlockManager; remote blocks are fetched using the provided BlockTransferService.

This yields an iterator of (BlockId, InputStream) tuples, so the caller can process blocks in a pipelined fashion as they are received.

The implementation throttles remote fetches so that they do not exceed maxBytesInFlight, to avoid using too much memory.

private[spark]
final class ShuffleBlockFetcherIterator(
    context: TaskContext, //used for metrics system updates
    shuffleClient: BlockStoreClient,//used to fetch remote blocks
    blockManager: BlockManager, //used to read local blocks
    mapOutputTracker: MapOutputTracker,//used to fall back to fetching the original blocks when a shuffle block cannot be fetched and push-based shuffle is enabled
    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], //list of blocks to fetch, grouped by [[BlockManagerId]]. For each block we also need two pieces of information: 1. its size (as a Long field) to limit memory usage; 2. its mapIndex, the index in the map stage. Note that zero-sized blocks have already been excluded, which happens in [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
    streamWrapper: (BlockId, InputStream) => InputStream,//a function to wrap the returned input stream
    maxBytesInFlight: Long, //maximum size (in bytes) of remote blocks to fetch at any given point
    maxReqsInFlight: Int, //maximum number of remote requests fetching blocks at any given point
    maxBlocksInFlightPerAddress: Int, //maximum number of shuffle blocks being fetched from a given remote host:port at any given point
    val maxReqSizeShuffleToMem: Long, //maximum size (in bytes) of a request that can be shuffled to memory
    maxAttemptsOnNettyOOM: Int, //maximum number of times a block can be retried due to Netty OOM before throwing a fetch failure
    detectCorrupt: Boolean,  //whether to detect corruption in the fetched blocks
    detectCorruptUseExtraMemory: Boolean,
    checksumEnabled: Boolean, //whether shuffle checksums are enabled; when enabled, Spark tries to diagnose the cause of block corruption
    checksumAlgorithm: String, //the checksum algorithm used when computing checksum values for block data
    shuffleMetrics: ShuffleReadMetricsReporter, //used to report shuffle metrics
    doBatchFetch: Boolean) //fetch continuous shuffle blocks from the same executor in batch if the server side supports it
  extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {

  //Remote requests are at most maxBytesInFlight / 5 in length; keeping them smaller than maxBytesInFlight
  //allows several parallel fetches from up to 5 nodes, rather than blocking on reading output from one node.
  private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L)

  //Host-local blocks to fetch, excluding zero-sized blocks
  private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()

  //Queue of fetch results.
  //This turns the asynchronous model provided by [[org.apache.spark.network.BlockTransferService]]
  //into a synchronous model (an iterator).
  private[this] val results = new LinkedBlockingQueue[FetchResult]

  //The [[FetchResult]] currently being processed. We track it so that the current buffer can be released if a runtime exception occurs while processing it.
  @volatile private[this] var currentResult: SuccessFetchResult = null

  //Queue of fetch requests to issue; we dequeue them incrementally to keep the number of bytes in flight below maxBytesInFlight.
  private[this] val fetchRequests = new Queue[FetchRequest]

  initialize()

  private[this] def initialize(): Unit = {
    // Add a task-completion callback for cleanup (called in both success and failure cases)
    context.addTaskCompletionListener(onCompleteCallback)
    // Local blocks to fetch, excluding zero-sized blocks.
    val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
    val hostLocalBlocksByExecutor =
      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
    val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
    //Partition the blocks by fetch mode: local, host-local, push-merged-local and remote blocks.
    val remoteRequests = partitionBlocksByFetchMode(
      blocksByAddress, localBlocks, hostLocalBlocksByExecutor, pushMergedLocalBlocks)
    //Add the remote requests to our queue in a random order so that the traffic is not concentrated on a single node
    fetchRequests ++= Utils.randomize(remoteRequests)
    assert ((0 == reqsInFlight) == (0 == bytesInFlight),
      "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
      ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

    // Send out the initial block requests, up to maxBytesInFlight
    fetchUpToMaxBytes()

    val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum
    val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest
    logInfo(s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" +
      (if (numDeferredRequest > 0 ) s", deferred $numDeferredRequest requests" else ""))

    //Fetch local blocks
    fetchLocalBlocks(localBlocks)
    logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}")
    //Fetch host-local blocks
    fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)
    pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks)
  }

  /**
   * Called from initialize() and from the fallback triggered by [[PushBasedFetchHelper]]
   */
  private[this] def partitionBlocksByFetchMode(
      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
    logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
      + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")

    //Partition into local, host-local, push-merged-local and remote (including push-merged-remote) blocks.
    //Remote blocks are further split into FetchRequests of size at most maxBytesInFlight to limit the amount of data in flight
    val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
    var localBlockBytes = 0L
    var hostLocalBlockBytes = 0L
    var numHostLocalBlocks = 0
    var pushMergedLocalBlockBytes = 0L
    val prevNumBlocksToFetch = numBlocksToFetch

    val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
    val localExecIds = Set(blockManager.blockManagerId.executorId, fallback)
    for ((address, blockInfos) <- blocksByAddress) {
      checkBlockSizes(blockInfos)
      if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
        //These are push-merged blocks or shuffle chunks of those blocks
        if (address.host == blockManager.blockManagerId.host) {
          numBlocksToFetch += blockInfos.size
          pushMergedLocalBlocks ++= blockInfos.map(_._1)
          pushMergedLocalBlockBytes += blockInfos.map(_._2).sum
        } else {
          //Push-merged blocks that live on a remote host: collect remote fetch requests for them
          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
        }
      } else if (localExecIds.contains(address.executorId)) { //blocks on this executor
        val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
        numBlocksToFetch += mergedBlockInfos.size
        localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
        localBlockBytes += mergedBlockInfos.map(_.size).sum
      } else if (blockManager.hostLocalDirManager.isDefined &&
        address.host == blockManager.blockManagerId.host) { //blocks on the same host
        val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
        numBlocksToFetch += mergedBlockInfos.size
        val blocksForAddress =
          mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
        hostLocalBlocksByExecutor += address -> blocksForAddress
        numHostLocalBlocks += blocksForAddress.size
        hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
      } else {
        //Needs a remote fetch
        val (_, timeCost) = Utils.timeTakenMs[Unit] {
          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
        }
        logDebug(s"Collected remote fetch requests for $address in $timeCost ms")
      }
    }
    val (remoteBlockBytes, numRemoteBlocks) =
      collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
      pushMergedLocalBlockBytes
    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
    
    this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values
      .flatMap { infos => infos.map(info => (info._1, info._3)) }
    collectedRemoteRequests
  }

  private def fetchUpToMaxBytes(): Unit = {
    if (isNettyOOMOnShuffle.get()) {
      if (reqsInFlight > 0) {
        //If Netty is still OOMed and there are in-flight fetch requests, return immediately
        return
      } else {
        resetNettyOOMFlagIfPossible(0)
      }
    }

    // Send fetch requests up to maxBytesInFlight. If a request cannot be sent to the remote host right now, defer it until it can be processed.

    // Process any outstanding deferred fetch requests if possible.
    if (deferredFetchRequests.nonEmpty) {
      for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
        while (isRemoteBlockFetchable(defReqQueue) &&
            !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
          val request = defReqQueue.dequeue()
          logDebug(s"Processing deferred fetch request for $remoteAddress with "
            + s"${request.blocks.length} blocks")
          send(remoteAddress, request)
          if (defReqQueue.isEmpty) {
            deferredFetchRequests -= remoteAddress
          }
        }
      }
    }

    // Process any regular fetch requests if possible.
    while (isRemoteBlockFetchable(fetchRequests)) {
      val request = fetchRequests.dequeue()
      val remoteAddress = request.address
      if (isRemoteAddressMaxedOut(remoteAddress, request)) {
        logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
        val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
        defReqQueue.enqueue(request)
        deferredFetchRequests(remoteAddress) = defReqQueue
      } else {
        send(remoteAddress, request)
      }
    }

    //Every fetch request ultimately goes through this method
    def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
      if (request.forMergedMetas) {
        pushBasedFetchHelper.sendFetchMergedStatusRequest(request)
      } else {
        sendRequest(request)
      }
      numBlocksInFlightPerAddress(remoteAddress) =
        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
    }

    def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
      fetchReqQueue.nonEmpty &&
        (bytesInFlight == 0 ||
          (reqsInFlight + 1 <= maxReqsInFlight &&
            bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
    }

    // Checks whether sending a new fetch request would exceed the maximum number of blocks being fetched from a given remote address.
    def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
      numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
        maxBlocksInFlightPerAddress
    }
  }


  //Send a request to a remote node for a set of blocks
  private[this] def sendRequest(req: FetchRequest): Unit = {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size
    reqsInFlight += 1

    // So we can look up the block info for each blockId
    val infoMap = req.blocks.map {
      case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex))
    }.toMap
    val remainingBlocks = new HashSet[String]() ++= infoMap.keys
    val deferredBlocks = new ArrayBuffer[String]()
    val blockIds = req.blocks.map(_.blockId.toString)
    val address = req.address

    @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
      if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
        val blocks = deferredBlocks.map { blockId =>
          val (size, mapIndex) = infoMap(blockId)
          FetchBlockInfo(BlockId(blockId), size, mapIndex)
        }
        //The results queue also carries deferred fetch requests, to be retried later
        results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq)))
        deferredBlocks.clear()
      }
    }

    //Listener invoked by the transport layer for each block
    val blockFetchingListener = new BlockFetchingListener {
      //Invoked when a block has been fetched successfully
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // Only add the buffer to the results queue if the iterator is not a zombie, i.e. cleanup() has not been called yet.
        ShuffleBlockFetcherIterator.this.synchronized {
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // It must be released after use.
            buf.retain()
            remainingBlocks -= blockId
            blockOOMRetryCounts.remove(blockId)
            //Put the successfully fetched block into results
            results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
              address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
            logDebug("remainingBlocks: " + remainingBlocks)
            enqueueDeferredFetchRequestIfNecessary()
          }
        }
        logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}")
      }

      //Invoked when fetching a block fails
      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        ShuffleBlockFetcherIterator.this.synchronized {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          e match {
            //SPARK-27991: catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared across tasks)
            //to true as early as possible. Pending fetch requests will not be sent afterwards until:
            //1) Netty free memory >= maxReqSizeShuffleToMem, which we check whenever a fetch request
            //succeeds; or 2) the number of in-flight requests becomes 0, which we check inside
            //`fetchUpToMaxBytes` whenever it is called. Although Netty memory is shared by multiple
            //modules (e.g. shuffle, rpc), the flag only takes effect for shuffle, for implementation
            //simplicity. Consecutive block failures caused by the OOM are buffered until no block remains
            //in the current request; those blocks are then packaged into a single new fetch request to be
            //retried later. This helps reduce the number of concurrent connections and the data load
            //pressure on the remote server, compared to creating one fetch request per block. Note that
            //catching the OOM and acting on it is only a workaround for the Netty OOM issue rather than a
            //best practice of memory management; it can be removed once Netty memory can be managed precisely.
            case _: OutOfDirectMemoryError
                if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM =>
              if (!isZombie) {
                val failureTimes = blockOOMRetryCounts(blockId)
                blockOOMRetryCounts(blockId) += 1
                if (isNettyOOMOnShuffle.compareAndSet(false, true)) {
                  // The fetcher can fail remaining blocks in batch for the same error. So we only
                  // log the warning once to avoid flooding the logs.
                  logInfo(s"Block $blockId has failed $failureTimes times " +
                    s"due to Netty OOM, will retry")
                }
                remainingBlocks -= blockId
                deferredBlocks += blockId
                enqueueDeferredFetchRequestIfNecessary()
              }

            case _ =>
              val block = BlockId(blockId)
              if (block.isShuffleChunk) {
                remainingBlocks -= blockId
                //The results queue also carries failure results
                results.put(FallbackOnPushMergedFailureResult(
                  block, address, infoMap(blockId)._1, remainingBlocks.isEmpty))
              } else {
                results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
              }
          }
        }
      }
    }

    // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
    // already encrypted and compressed over the wire (with respect to the relevant configs), we can
    // just fetch the data and write it straight to a file.
    //Delegate the actual fetching to the BlockStoreClient
    if (req.size > maxReqSizeShuffleToMem) {
      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
        blockFetchingListener, this)
    } else {
      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
        blockFetchingListener, null)
    }
  }

}
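
The core trick of this class is the results queue: the transport layer's callbacks (onBlockFetchSuccess / onBlockFetchFailure) put FetchResults into a LinkedBlockingQueue, and next() simply takes from it, turning the asynchronous fetches into a synchronous iterator. A self-contained sketch of just that pattern (not Spark code, only the idea):

import java.util.concurrent.LinkedBlockingQueue

// Minimal illustration of the async-callback -> blocking-queue -> iterator pattern.
object AsyncToSyncIteratorDemo {
  sealed trait FetchResult
  case class Success(blockId: String, data: Array[Byte]) extends FetchResult
  case class Failure(blockId: String, cause: Throwable) extends FetchResult

  def main(args: Array[String]): Unit = {
    val results = new LinkedBlockingQueue[FetchResult]()
    val numBlocks = 3
    // Pretend these threads are the transport layer's fetch callbacks.
    (1 to numBlocks).foreach { i =>
      new Thread(() => results.put(Success(s"shuffle_0_${i}_0", Array.fill(4)(i.toByte)))).start()
    }
    // The consumer side blocks on take(), just like ShuffleBlockFetcherIterator.next() does.
    val iter = Iterator.fill(numBlocks)(results.take())
    iter.foreach {
      case Success(id, data) => println(s"$id -> ${data.length} bytes")
      case Failure(id, e)    => println(s"$id failed: $e")
    }
  }
}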

3. BlockStoreClient

An interface for reading shuffle files and RDD blocks, either from an executor or from an external service. Its concrete implementations include NettyBlockTransferService (fetching directly from another executor) and ExternalBlockStoreClient (fetching from the external shuffle service).

public abstract class BlockStoreClient implements Closeable {

  //Asynchronously fetch a sequence of blocks from a remote node.
  //Note that this API takes a sequence so that implementations can batch the requests, and it does not
  //return a future, so the underlying implementation can invoke onBlockFetchSuccess as soon as a block's
  //data is fetched instead of waiting for all blocks to be fetched.
  //DownloadFileManager creates and cleans up temp files. If it is not null, remote blocks are streamed
  //into temp shuffle files to reduce memory usage; otherwise they are kept in memory.
  public abstract void fetchBlocks(
      String host,
      int port,
      String execId,
      String[] blockIds,
      BlockFetchingListener listener,
      DownloadFileManager downloadFileManager);

}
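
The listener argument is the same BlockFetchingListener we saw in sendRequest above. A minimal hedged sketch of implementing one directly (illustrative only; in practice ShuffleBlockFetcherIterator supplies it):

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.BlockFetchingListener

// Minimal listener: just log each callback. If the buffer were handed to another thread,
// it would need to be retained (as ShuffleBlockFetcherIterator does with buf.retain()).
val listener = new BlockFetchingListener {
  override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit =
    println(s"fetched $blockId (${data.size()} bytes)")
  override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit =
    println(s"failed to fetch $blockId: $exception")
}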

4. NettyBlockTransferService

A BlockTransferService that fetches a set of blocks at a time. Each instance of BlockTransferService contains both a client and a server.

private[spark] class NettyBlockTransferService(
    conf: SparkConf,
    securityManager: SecurityManager,
    bindAddress: String,
    override val hostName: String,
    _port: Int,
    numCores: Int,
    driverEndPointRef: RpcEndpointRef = null)
  extends BlockTransferService {


  override def fetchBlocks(
      host: String,
      port: Int,
      execId: String,
      blockIds: Array[String],
      listener: BlockFetchingListener,
      tempFileManager: DownloadFileManager): Unit = {
    if (logger.isTraceEnabled) {
      logger.trace(s"Fetch blocks from $host:$port (executor id $execId)")
    }
    try {
      val maxRetries = transportConf.maxIORetries()
      val blockFetchStarter = new RetryingBlockTransferor.BlockTransferStarter {
        override def createAndStart(blockIds: Array[String],
            listener: BlockTransferListener): Unit = {
          assert(listener.isInstanceOf[BlockFetchingListener],
            s"Expecting a BlockFetchingListener, but got ${listener.getClass}")
          try {
            //Create a Netty client
            val client = clientFactory.createClient(host, port, maxRetries > 0)
            //Fetch the data
            new OneForOneBlockFetcher(client, appId, execId, blockIds,
              listener.asInstanceOf[BlockFetchingListener], transportConf, tempFileManager).start()
          } catch {
            ......
          }
        }
      }

      if (maxRetries > 0) {
        // Note that this fetcher would correctly handle maxRetries == 0; we avoid it just in case there's
        // a bug in this code. Once it's proven stable, the if statement should be removed.
        new RetryingBlockTransferor(transportConf, blockFetchStarter, blockIds, listener).start()
      } else {
        blockFetchStarter.createAndStart(blockIds, listener)
      }
    } catch {

    }
  }


}

5. TransportClientFactory creates the TransportClient

Creates a TransportClient connected to the given remote host and port.

We maintain an array of TransportClients per peer (its size determined by spark.shuffle.io.numConnectionsPerPeer) and pick one at random to use. If no TransportClient was previously created at the randomly chosen slot, this function creates a new one and places it there.

spark.shuffle.io.numConnectionsPerPeer, default 1 (Netty only: connections between hosts are reused to reduce connection build-up for large clusters. For clusters with many hard disks and few hosts, this may result in insufficient concurrency to saturate all the disks, so users may consider increasing this value.)

If the fastFail parameter is true, the call fails immediately when the last attempt to the same address failed within the fast-fail time window (95% of the IO retry-wait timeout), assuming the caller will handle the retry.

Before creating a new TransportClient, all TransportClientBootstraps registered with this factory are run.

This call blocks until a connection is successfully established and fully bootstrapped.

public class TransportClientFactory implements Closeable {

  public TransportClient createClient(String remoteHost, int remotePort, boolean fastFail)
      throws IOException, InterruptedException {
    //First get a connection from the connection pool. If it is not found or not active, create a new one.
    //Use the unresolved address here to avoid DNS resolution each time a client is created.
    final InetSocketAddress unresolvedAddress =
      InetSocketAddress.createUnresolved(remoteHost, remotePort);

    // Create a ClientPool for this address if there isn't one yet.
    //Layer 1: ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
    //Layer 2: ClientPool { TransportClient[] clients }
    //i.e. multiple TransportClients can be created per remote node, and one is picked at random to fetch data.
    //By default only one client per node is created; for clusters with few hosts and many disks,
    //increase spark.shuffle.io.numConnectionsPerPeer to get more concurrency.
    ClientPool clientPool = connectionPool.get(unresolvedAddress);
    if (clientPool == null) {
      connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
      clientPool = connectionPool.get(unresolvedAddress);
    }

    int clientIndex = rand.nextInt(numConnectionsPerPeer);
    TransportClient cachedClient = clientPool.clients[clientIndex];

    if (cachedClient != null && cachedClient.isActive()) {
      // Make sure the channel will not time out by updating the handler's last-used time. Then check
      // whether the TransportClient is still active, in case it timed out before this code could update things.
      TransportChannelHandler handler = cachedClient.getChannel().pipeline()
        .get(TransportChannelHandler.class);
      synchronized (handler) {
        handler.getResponseHandler().updateTimeOfLastRequest();
      }

      if (cachedClient.isActive()) {
        logger.trace("Returning cached connection to {}: {}",
          cachedClient.getSocketAddress(), cachedClient);
        return cachedClient;
      }
    }

    // If we reach here, there is no existing open connection, so create a new one. Multiple threads
    // might race here to create new connections; only one of them will be kept active.
    final long preResolveHost = System.nanoTime();
    final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
    final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
    final String resolvMsg = resolvedAddress.isUnresolved() ? "failed" : "succeed";
    if (hostResolveTimeMs > 2000) {
      logger.warn("DNS resolution {} for {} took {} ms",
          resolvMsg, resolvedAddress, hostResolveTimeMs);
    } else {
      logger.trace("DNS resolution {} for {} took {} ms",
          resolvMsg, resolvedAddress, hostResolveTimeMs);
    }

    synchronized (clientPool.locks[clientIndex]) {
      cachedClient = clientPool.clients[clientIndex];

      if (cachedClient != null) {
        if (cachedClient.isActive()) {
          logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
          return cachedClient;
        } else {
          logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
        }
      }
      // If fast fail is enabled and the last connection to this address failed within the fast-fail
      // time window, fail this connection directly.
      if (fastFail && System.currentTimeMillis() - clientPool.lastConnectionFailed <
        fastFailTimeWindow) {
        throw new IOException(
          String.format("Connecting to %s failed in the last %s ms, fail this connection directly",
            resolvedAddress, fastFailTimeWindow));
      }
      try {
        //Create the actual Netty client
        clientPool.clients[clientIndex] = createClient(resolvedAddress);
        clientPool.lastConnectionFailed = 0;
      } catch (IOException e) {
        clientPool.lastConnectionFailed = System.currentTimeMillis();
        throw e;
      }
      return clientPool.clients[clientIndex];
    }
  }

  //Create a Netty client
  TransportClient createClient(InetSocketAddress address)
      throws IOException, InterruptedException {
    logger.debug("Creating new connection to {}", address);

    Bootstrap bootstrap = new Bootstrap();
    bootstrap.group(workerGroup)
      .channel(socketChannelClass)
      // Disable Nagle's algorithm because we don't want packets to wait
      .option(ChannelOption.TCP_NODELAY, true)
      .option(ChannelOption.SO_KEEPALIVE, true)
      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionCreationTimeoutMs())
      .option(ChannelOption.ALLOCATOR, pooledAllocator);

    if (conf.receiveBuf() > 0) {
      bootstrap.option(ChannelOption.SO_RCVBUF, conf.receiveBuf());
    }

    if (conf.sendBuf() > 0) {
      bootstrap.option(ChannelOption.SO_SNDBUF, conf.sendBuf());
    }

    final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
    final AtomicReference<Channel> channelRef = new AtomicReference<>();

    bootstrap.handler(new ChannelInitializer<SocketChannel>() {
      @Override
      public void initChannel(SocketChannel ch) {
        TransportChannelHandler clientHandler = context.initializePipeline(ch);
        clientRef.set(clientHandler.getClient());
        channelRef.set(ch);
      }
    });

    // Connect to the remote server
    long preConnect = System.nanoTime();
    ChannelFuture cf = bootstrap.connect(address);
    // ... (waiting on the connect future and running the client bootstraps is elided here) ...
    TransportClient client = clientRef.get();

    return client;
  }


}
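
Both the per-peer pool size and the IO retry behaviour used along this path are plain Spark configs. A hedged sketch of tuning them, e.g. pasted into spark-shell (illustrative values only):

import org.apache.spark.SparkConf

// Illustrative values only; the defaults are usually fine.
val conf = new SparkConf()
  .set("spark.shuffle.io.numConnectionsPerPeer", "2") // default 1: size of the TransportClient pool per remote host
  .set("spark.shuffle.io.maxRetries", "6")            // default 3: retries handled by RetryingBlockTransferor
  .set("spark.shuffle.io.retryWait", "10s")           // default 5s: wait between retries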

6. OneForOneBlockFetcher

A fetcher that pulls blocks from a remote node's shuffle service. Its counterpart is OneForOneBlockPusher, which pushes blocks to a remote shuffle service to be merged; that class is covered in "Spark-push-based shuffle".

public class OneForOneBlockFetcher {

  //Begin fetching the data, invoking the listener for every block that arrives. The given message is
  //serialized with the Java serializer, and the RPC must return a {@link StreamHandle}. All fetch
  //requests are sent immediately, without throttling.
  public void start() {
    //Send the RPC request
    client.sendRpc(message.toByteBuffer(), new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        try {
          streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
          logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);

          // Immediately request all chunks -- we expect the total size of the request to be
          // reasonable thanks to the higher-level chunking in [[ShuffleBlockFetcherIterator]].
          for (int i = 0; i < streamHandle.numChunks; i++) {
            if (downloadFileManager != null) {
              client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
                new DownloadCallback(i));
            } else {
              client.fetchChunk(streamHandle.streamId, i, chunkCallback);
            }
          }
        } catch (Exception e) {
          logger.error("Failed while starting block fetches after success", e);
          failRemainingBlocks(blockIds, e);
        }
      }

      @Override
      public void onFailure(Throwable e) {
        logger.error("Failed while starting block fetches", e);
        failRemainingBlocks(blockIds, e);
      }
    });
  }

}

5. Wrapping the data into an iterator

As we saw in the source of BlockStoreShuffleReader.read(), once the data is ready it is wrapped into a key/value iterator and then into an interruptible iterator (to support task cancellation); the reader then checks whether the dependency defines an aggregator and whether map-side combine was enabled.

The aggregator consists of three functions:
       1. createCombiner: creates the initial value of an aggregation
       2. mergeValue: merges a new value into the aggregation result
       3. mergeCombiners: merges the outputs of multiple mergeValue calls

If the map side already combined (for example with reduceByKey), steps 1 and 2 were done during the shuffle write, so the reduce side only needs step 3 and takes this path:

  val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
  dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)

//---------------aggregator-------------------------
  def combineCombinersByKey(
      iter: Iterator[_ <: Product2[K, C]],
      context: TaskContext): Iterator[(K, C)] = {
    val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
    combiners.insertAll(iter)
    updateMetrics(context, combiners)
    combiners.iterator
  }

Otherwise, if there was no map-side combine (for example with groupByKey), the reduce side has to run all three functions and takes this path:

  val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
  dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)

//---------------aggregator-------------------------
  def combineValuesByKey(
      iter: Iterator[_ <: Product2[K, V]],
      context: TaskContext): Iterator[(K, C)] = {
    val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
    combiners.insertAll(iter)
    updateMetrics(context, combiners)
    combiners.iterator
  }
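
For reference, the aggregator on the ShuffleDependency is nothing more than these three functions. A hedged sketch of what a reduceByKey-style aggregator amounts to (illustrative only, not copied from Spark's PairRDDFunctions):

import org.apache.spark.Aggregator

// Sum values per key: createCombiner starts a combiner from the first value,
// mergeValue folds further values in, mergeCombiners merges partial results.
val sumAggregator = Aggregator[String, Int, Int](
  createCombiner = (v: Int) => v,
  mergeValue = (c: Int, v: Int) => c + v,
  mergeCombiners = (c1: Int, c2: Int) => c1 + c2)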

If the operator defines a key ordering, the iterator additionally has to be sorted:

    val resultIter: Iterator[Product2[K, C]] = dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAllAndUpdateMetrics(aggregatedIter)
      case None =>
        aggregatedIter
    }
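
Which operators make dep.keyOrdering defined? sortByKey (and repartitionAndSortWithinPartitions) build a ShuffledRDD and call setKeyOrdering on it. A simplified sketch of that idea (written from memory, not verbatim Spark code):

import scala.reflect.ClassTag
import org.apache.spark.RangePartitioner
import org.apache.spark.rdd.{RDD, ShuffledRDD}

// Simplified idea of sortByKey: build a ShuffledRDD and attach a key ordering,
// which is what makes dep.keyOrdering defined in BlockStoreShuffleReader.read().
def sortByKeySketch[K: Ordering: ClassTag, V: ClassTag](
    rdd: RDD[(K, V)], numPartitions: Int): RDD[(K, V)] = {
  val part = new RangePartitioner(numPartitions, rdd)
  new ShuffledRDD[K, V, V](rdd, part).setKeyOrdering(implicitly[Ordering[K]])
}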

Finally, an interruptible iterator is returned for the downstream computation to consume.
