SparkCore — ShuffleReader过程

Shuffle Reader

  在之前的博客中,分析了shuffle map端的操作,map最终会将输出文件信息封装为一个MapStatus发送给Driver,然后ResultTask或ShuffleMapTask在拉取数据的时候,会先去Driver上拉取自己要读取数据的信息,比如在哪个节点上,以及在文件中的位置。下面我们来分析一下ShuffleReader,首先Map操作结束之后产生的RDD是ShuffledRDD,它会调用ShuffleManager的getReader()方法,这个方法里面传入了上一个stage的信息,拉取文件信息的offset,接着调用它的read()方法:

ShuffledRDD的compute()和BlockStoreShuffleReader的read()方法
 override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
	 // ResultTask或ShuffleMapTask,在生成ShuffledRDD并处理的时候
    // 会调用它的compute方法,来计算当前这个RDD的partition的数据
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    // 这里调用ShuffleManager的getReader的read()方法,拉取ResultTask或ShuffleMapTask所需的数据
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
}

override def read(): Iterator[Product2[K, C]] = {
    // BlockStoreReader实例化的时候,传入的参数会获取MapOutputTracker对象,
    // 调用其getMapSizesByExecutorId方法,创建一个Iterator,用于遍历待获取数据的位置信息。
    // 注意传入的参数,shuffleId,代表上一个stage;
    // startPartition:是当前需要的数据在输出文件中的起始offset,endPartition:是结束offset
    // 通过这两个限制从MapOutputTracker上拉取所需信息在节点上的位置信息
    // 在实例化ShuffleBlockFetcherIterator的时候,会调用它的initialize()方法,
    // 在这个方法里面,会根据拉取到的文件位置信息去对应的worker节点的BlockManager上拉取数据。
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      // 获取数据的位置信息
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

    // 创建数据输入流读取数据,以及是否需要解压等
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      blockManager.wrapForCompression(blockId, inputStream)
    }
    // 创建序列化实例
    val ser = Serializer.getSerializer(dep.serializer)
    val serializerInstance = ser.newInstance()

    // Create a key/value iterator for each stream
    // 将读取到的数据进行反序列化操作
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the context task metrics for each record read.
    // 下面就是对数据的一些操作,比如是否需要聚合,排序等等
    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map(record => {
        readMetrics.incRecordsRead(1)
        record
      }),
      context.taskMetrics().updateShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // We are reading values that are already combined
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure its compatible w/ this aggregator, which will convert the value
        // type to the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.internalMetricsToAccumulators(
          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }

  这里最重要的在实例化ShuffleBlockFetcherIterator()的时候,就会去远程读取数据,这里面有两个重要的方法,一个是获取要拉取文件的信息getMapSizesByExecutorId(),还有一个是ShuffleBlockFetcherIterator在实例化的时候调用的initialize()方法,下面我们先分析如何拉取当前ResultTask(或ShuffleMapTask)所需信息的位置:

MapOutputTracker的getMapSizesByExecutorId()方法

  首先我们看一下源码:

 def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
    // 获取数据的位置信息
    val statuses = getStatuses(shuffleId)
    // Synchronize on the returned array because, on the driver, it gets mutated in place
    statuses.synchronized {
      // 将获取到的数据存储到BlockManager上。
      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
    }
  }

  其实这个方法里面封装了两个两个子方法,一个是获取数据位置信息的getStatus;还有一个就是将获取到的信息提取出来放入队列中。
下面我们先看getStatus()

getStatuses
private def getStatuses(shuffleId: Int): Array[MapStatus] = {
    // 获取shuffleId在输出文件中的每个partition写入位置offset,
    // 看一下当前缓存是否有之前拉取的数据
    val statuses = mapStatuses.get(shuffleId).orNull
    // 如果不存在
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      val startTime = System.currentTimeMillis
      var fetchedStatuses: Array[MapStatus] = null
      // 有可能其他的ResultTask在拉取这个shuffleId的数据,等待对方拉取完成
      fetching.synchronized {
        // Someone else is fetching it; wait for them to be done
        while (fetching.contains(shuffleId)) {
          try {
            // 等待唤醒
            fetching.wait()
          } catch {
            case e: InterruptedException =>
          }
        }

        // Either while we waited the fetch happened successfully, or
        // someone fetched it in between the get and the fetching.synchronized.
        // 再次获取一遍
        fetchedStatuses = mapStatuses.get(shuffleId).orNull
        if (fetchedStatuses == null) {
          // We have to do the fetch, get others to wait for us.
          // 将其加入等待队列,下一次优先进行数据拉取
          fetching += shuffleId
        }
      }

      // 获得当前拉取数据的权限
      if (fetchedStatuses == null) {
        // We won the race to fetch the statuses; do so
        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
        // This try-finally prevents hangs due to timeouts:
        try {
          // 发送GetMapOutputStatuses消息,从MapOutputTracker上拉取数据
          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
          // 将获取的数据反序列化
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")
          // 将拉取到的数据
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            // 清除当前等待的shuffleId
            fetching -= shuffleId
            // 唤醒其他线程
            fetching.notifyAll()
          }
        }
      }
      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
        s"${System.currentTimeMillis - startTime} ms")

      if (fetchedStatuses != null) {
        return fetchedStatuses
      } else {
        logError("Missing all output locations for shuffle " + shuffleId)
        throw new MetadataFetchFailedException(
          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
      }
    } else {
      // 假如缓存有之前拉取的数据,那么直接返回
      return statuses
    }
  }

  首先看一下当前缓存是否已经包含这个shuffleId输出文件信息,假如包含,那么就这就返回即可。假设没有,如果fetching等待队列中包含当前需要拉取的shuffleId,先阻塞在这边等待其他ResultTask(或ShuffleMapTask)获取完成;被唤醒以后接着获取一次status,假设还没有获取到,那么就开始获取。开启拉取数据信息,使用的是askTracker()方法,参数是GetMapOutputStatus信息,它向Driver的MapOutputTracker发送这条信息,去获取当前这个ShuffleId的输出文件信息,Driver上的MapOutputTracker接收到这条信息后,就会获取当前这shuffleId的相关信息,然后在将获取到的信息发送给当前这个ResultTask(或ShuffleMapTask)。这里fetchedStatuses就是Driver端MapOutputTracker发送过来的待获取数据的位置信息。然后将数据反序列化存入map缓存中;接着在唤醒其他等待线程。
  在获取到需要拉取数据的位置信息之后,就调用convertMapStatuses()解析刚刚获取到的位置信息,将要拉取的位置信息提取出来,放入队列中,并返回。
  上面这个就获取到了需要拉取数据的位置信息,那么下一步就是去拉取数据,拉取数据的过程就在实例化ShuffleBlockFetcherIterator的时候,调用的initialize()方法中。

ShuffleBlockFetcherIterator的初始化方法initialize()
/**
    *   将这个方法作为入口,开始拉取ResultTask对应的多份数据
    */
  private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    context.addTaskCompletionListener(_ => cleanup())

    // Split local and remote blocks.
    // 切分本地和远程Block
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests into our queue in a random order
    // 切分完Block之后,进行随机排序操作
    fetchRequests ++= Utils.randomize(remoteRequests)

    // Send out initial requests for blocks, up to our maxBytesInFlight
    // 循环往复拉取数据,只要发现数据还没有拉取完,就发送请求到远程拉取数据
    // 这里有一个参数比较重要,就是maxBytesInFlight,代表ResultTask最多能拉取多少数据
    // 到本地,就要开始进行自定义的reduce算子的处理
    fetchUpToMaxBytes()

    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

    // Get Local Blocks
    // 获取本地的数据
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }

  首先对将要拉取的数据信息进行区分,切分为本地和远程拉取,首先拉取远程worker节点上的数据,fetchUpToMaxBytes(),它会不断的拉取数据,直到数据拉取完或者当前拉取的缓存以及满了(默认48M,maxBytesInFlight),然后接着调用fetchLocalBlocks(),拉取在本地节点上的数据。这样这个ResultTask(或ShuffleMapTask)的数据就拉取到本地缓存了。这里我们先不对fetchUpToMaxBytes和fetchLocalBlocks做详细的分析了。
  总结一下,这里主要是和Driver端的MapOutputTracker进行通信,获取当前ResultTask(或ShuffleMapTask)要拉取的文件的位置信息,从获取到的文件位置信息里提取出当前这个Task所需的位置信息,然后通过BlockManager去远程或本地拉取需要的信息这里有个参数需要注意一下(spark.reducer.maxSizeInFlight,默认48M,代表当前reduce端最大能存储的拉取数据缓存大小)。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值