Spark Shuffle源码分析

task.run.runTask->ShuffleMapTask.runTask->writer.write
writer 有 HashShuffleWriter和SortShuffleWriter
本章分析 HashShuffleWriter

Shuffle Write

  /**
   *  Write a bunch of records to this task's output 
   *  将每个shuffleMapTask计算出来的新的RDD的partition数据写入本地磁盘
   */
  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {

    /**
     *  首先判断,是否需要在map端本地聚合
     *  如果reduceByKey这种操作,它的dep.aggregator.isDegined就是true
     *  那么就会进行map端的本地聚合
     */
    val iter = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // 本地聚合  如:(hello,1) (hello,1) ---> (hello,2)
        dep.aggregator.get.combineValuesByKey(records, context)
      } else {
        records
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      records
    }

    /**
     *  如果要本地聚合,那么先本地聚合
     *  然后遍历数据
     *  对每个数据,调用partitioner,默认是HashPartitioner生成bucketId
     *  也就是决定了,每一份数据,要写入那个bucket中
     */
    for (elem <- iter) {
      val bucketId = dep.partitioner.getPartition(elem._1)

      /**
       *  获取到bucketId后,会调用shuffleBlockManager.forMapTask()方法,生成bucketId对应的writer,
       *  然后用writer将数据写入buket
       */
      shuffle.writers(bucketId).write(elem)
    }
  }

=> shuffle.writers -> FileShuffleBlockManager.forMapTask.writers
  /**
   * 给每个map task获取一个shufflewritegroup
   */
  def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
      writeMetrics: ShuffleWriteMetrics) = {
    new ShuffleWriterGroup {
      shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
      private val shuffleState = shuffleStates(shuffleId)
      private var fileGroup: ShuffleFileGroup = null

      /**
       * shuffle的两种模式:
       *  1)开启consolication机制:consolidateShuffleFiles=true,不会给每个bucket都获取一个独立的文件
       *  而是为这个bucket获取一个ShuffleGroup的writer
       *  2) 未开启consolication机制 consolidateShuffleFiles=false
       */
      val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
        fileGroup = getUnusedFileGroup()
        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>

          /**
           *  首先用shuffleId,mapId,bucketId(reduceId)生成一个唯一的ShuffleBlockId
           *  然后用bucketId,来调用shufflefileGroup的apply函数,为bucket获取一个shufflefilegroup
           */
          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)

          /**
           *  针对ShuffleFileGroup获取一个writer
           *  如果开启了consolidation机制,对于每一个bucket,都会获取一个针对ShuffleFileGroup的writer
           *  而不是一个独立的ShuffleBlockFile的writer
           *  这样就实现了多个shuffleMapTask输出数据的合并
           */
          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
            writeMetrics)
        }
      } else {
        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
          // 获取一个代表了要写入的本地磁盘文件的blockfile
          val blockFile = blockManager.diskBlockManager.getFile(blockId)
          // Because of previous failures, the shuffle file may already exist on this machine.
          // If so, remove it.
          if (blockFile.exists) {
            if (blockFile.delete()) {
              logInfo(s"Removed existing shuffle file $blockFile")
            } else {
              logWarning(s"Failed to remove existing shuffle file $blockFile")
            }
          }
          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
        }
      }

==>blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,writeMetrics)
bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024    // 默认32kb
--> BlockManager.getDiskWriter
new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites,
      writeMetrics)

DiskBlockObjectWriter.write-> open
  override def open(): BlockObjectWriter = {
    if (hasBeenClosed) {
      throw new IllegalStateException("Writer already closed. Cannot be reopened.")
    }
    // java 文件输出流
    fos = new FileOutputStream(file, true)
    ts = new TimeTrackingOutputStream(fos)
    channel = fos.getChannel()

    /**
     * java 缓冲流 ,中传入 bufferSize,缓冲大小,当内存中数据达到这个值时就会异步写入磁盘
     * 至此,spark shufflewrite 最终调用 BufferedOutputStream  实现write
     */
    bs = compressStream(new BufferedOutputStream(ts, bufferSize))
    objOut = serializer.newInstance().serializeStream(bs)
    initialized = true
    this
  }

ShuffleReader

ShuffledRDD.compute 方法 调用 ShuffleReader
  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {

    /**
     *  ResultTask或者ShuffleMapTask,在执行到ShuffledRdd时,肯定会调用ShuffledRDD的compute方法
     *  来计算当前这个RDD的partition的数据
     *  在这里会调用shufflemanager的getReader方法,获取一个HashShuffleReader
     *  然后调用他的read方法,拉取该resultTask/shuffleMapTask需要聚合的数据
     */
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    // TODO  HashShuffleReader.read
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }
=> HashShuffleReader.read
  override def read(): Iterator[Product2[K, C]] = {
    val ser = Serializer.getSerializer(dep.serializer)

    /**
     *  TODO fetch
     *  resultTask在拉取数据时,其实会用BlockStoreShuffleFetcher来从DAGScheduler的MapOutputTrackerMaster
     *  中获取自己想要的数据的信息,然后底层再通过blockManager从对应的位置拉取需要的数据
     */
    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)

==>BlockStoreShuffleFetcher.fetch
  def fetch[T](
      shuffleId: Int,
      reduceId: Int,
      context: TaskContext,
      serializer: Serializer)
    : Iterator[T] =
  {
    logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
    val blockManager = SparkEnv.get.blockManager

    val startTime = System.currentTimeMillis
    
    /**
     * 重点
     * 拿到了全局的MapOutTrackerMaster的引用
     * 然后调用getServerStatuses方法,传入 shuffleId和reduceId
     * shuffleId 可以代表当前这个stage的上一个stage,shuffle分为两个stage:
     *   shuffle write 发生在上一个stage中
     *   shuffle read发生在当前的stage中
     * 
     * 理解:
     * 首先通过shuffleId可以限制上上一个stage的所有shuffleMapTask的输出的MapStatus
     * 接着,通过reduceId(bucketId)来限制从每个MapStatus中获取当前这个ResultTask需要
     * 获取的每个ShuffleMapTask的输出文件的信息
     * 
     * 这个getServerStatuses一定走远程网络通信的,因为要联系driver上的DAGScheduler的MapOutputTrackerMaster
     *
     * // TODO getServerStatuses
     */
    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)

    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
      shuffleId, reduceId, System.currentTimeMillis - startTime))

    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
    for (((address, size), index) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
    }

    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
      case (address, splits) =>
        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
    }

    def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
      val blockId = blockPair._1
      val blockOption = blockPair._2
      blockOption match {
        case Success(block) => {
          block.asInstanceOf[Iterator[T]]
        }
        case Failure(e) => {
          blockId match {
            case ShuffleBlockId(shufId, mapId, _) =>
              val address = statuses(mapId.toInt)._1
              throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
            case _ =>
              throw new SparkException(
                "Failed to get block " + blockId + ", which is not a shuffle block", e)
          }
        }
      }
    }

    
    /**
     * ShuffleBlockFetcherIterator构造后,在其内部就直接根据拉取到的地理位置信息,
     * 通过blockManager去远程的shuffleMapTask所在的节点的blockManager去拉取数据
     *
     *  TODO ShuffleBlockFetcherIterator.initialize
     */
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      SparkEnv.get.blockManager.shuffleClient,
      blockManager,
      blocksByAddress,
      serializer,
      SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
    val itr = blockFetcherItr.flatMap(unpackBlock)

    // 最后,将拉取到的数据进行一些转换和封装 返回
    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
      context.taskMetrics.updateShuffleReadMetrics()
    })

===>SparkEnv.get.mapOutputTracker.getServerStatuses -> MapOutputTracker.getServerStatuses
 def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      var fetchedStatuses: Array[MapStatus] = null
      fetching.synchronized {
        // Someone else is fetching it; wait for them to be done
        // 不断去拉取shuffleId对应的数据,只要还没拉到,死循环,等待
        while (fetching.contains(shuffleId)) {
          try {
            fetching.wait()
          } catch {
            case e: InterruptedException =>
          }
        }

        // Either while we waited the fetch happened successfully, or
        // someone fetched it in between the get and the fetching.synchronized.
        fetchedStatuses = mapStatuses.get(shuffleId).orNull
        if (fetchedStatuses == null) {
          // We have to do the fetch, get others to wait for us.
          fetching += shuffleId
        }
      }

      if (fetchedStatuses == null) {
        // We won the race to fetch the output locs; do so
        logInfo("Doing the fetch; tracker actor = " + trackerActor)
        // This try-finally prevents hangs due to timeouts:
        try {
          val fetchedBytes =
            askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      }
      if (fetchedStatuses != null) {
        fetchedStatuses.synchronized {
          return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
        }
      } else {
        logError("Missing all output locations for shuffle " + shuffleId)
        throw new MetadataFetchFailedException(
          shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
      }
    } else {
      statuses.synchronized {
        return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
      }
    }
  }

=> 再回到 new ShuffleBlockFetcherIterator -> ShuffleBlockFetcherIterator.initialize
  private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    context.addTaskCompletionListener(_ => cleanup())

    // Split local and remote blocks.
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests into our queue in a random order
    fetchRequests ++= Utils.randomize(remoteRequests)

    /**
     *  Send out initial requests for blocks, up to our maxBytesInFlight
     *
     *  循环,发现还有数据没有拉取完,就发送请求到远程去拉取
     *  调优参数: max.bytes.in.flight 最多能拉取多少数据到本地就要开始进行reduce操作
     */
    while (fetchRequests.nonEmpty &&
      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
      sendRequest(fetchRequests.dequeue())
    }

    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

    // Get Local Blocks
    // 拉取完了远程数据之后,拉取本地的数据(数据本地化)
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值