spark源码解析之四、shuffle那些事儿

关于任务切分与运行可以参考:spark源码解析之三、任务切分与运行
spark运行流程分为资源环境准备和任务提交运行两个步骤,两个步骤交叉进行,当前以任务提交为主线进行源码分析。
资源环境准备线,可以参考spark源码解析之二、计算资源准备
关于源代码的前期准备可以参考:spark源码解析之一、整体概述

一、shuffle概述

Shuffle 机制是 Spark Core 的核心内容。在 Stage 和 Stage 之间,Spark 需要 Shuffle 数据。这个流程包含上一个 Stage 上的 Shuffle Write,中间的数据传输,以及下一个 Stage 的 Shuffle Read。如下图所示:
在这里插入图片描述强调一点,shuffle分两个过程:shuffle读和shuffle写,shuffle读发生在一个stage的开始,shuffle写发生在ShuffleMapStage的结尾,shuffle读在前,shuffle写在后,且分布在一个stage的首尾,在后续的原码中也会有所体现。
在前几篇文章中分析了spark资源准备和stage切分、task提交的源代码,那么我们接下来分析任务task运行过程中必不可少的shuffle过程的原码。
task任务提交之后,在stage的最后阶段就是shuffle数据落盘的过程,数据落盘完成则标着者下一个stage的开始,下一个stage的数据来源就是上一个stage的shuffle数据文件。
spark中的task分为两类ResultTask和ShuffleMapTask,ResultTask作为最终阶段的task,写的过程主要是一些行动算子,不同的行动算子具有不同的逻辑,不具有代表性。而ShuffleMapTask具备了shuffle的通用功能读和写,所以主要从ShuffleMapTask的runtask开始。
shuffle从代码中可以看到,我们暂时先关注三件事情,ShuffleWriter的创建、Shuffle读以及shuffle写。

1.1 ShuffleMapTask.runTask

override def runTask(context: TaskContext): MapStatus = {
   ....
    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
	  //一:ShuffleWriter的创建
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      //二:Shuffle写主要是rdd.iterator
      //三:shuffle读主要是writer.write
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      writer.stop(success = true).get
    } catch {
   ....
    }
}

二、ShuffleWriter的创建

注册ShuffleHandle,不同的ShuffleHandle用于创建不同的ShuffleWriter。从这个判断方法中可以看到,其实spark框架是优先判断是否符合bypass机制,如果不符合在判断是否是序列化shuffle机制,如果两者都不符合才是baseshuffle机制。其实这个过程就像去某地,地图肯定会规划多条路径,最终给定一条最优解,也从侧面说明这几种机制的顺序实现起来是越来越麻烦。

2.1 SortShuffleManager.registerShuffle

   override def registerShuffle[K, V, C](
      shuffleId: Int,
      numMaps: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
		  //判断是否使用忽略合并排序
    if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
      // need map-side aggregation, then write numPartitions files directly and just concatenate
      // them at the end. This avoids doing serialization and deserialization twice to merge
      // together the spilled files, which would happen with the normal code path. The downside is
      // having multiple files open at a time and thus more memory allocated to buffers.
      new BypassMergeSortShuffleHandle[K, V](
        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
    } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
      // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
      new SerializedShuffleHandle[K, V](
        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
    } else {
      // Otherwise, buffer map outputs in a deserialized form:
      new BaseShuffleHandle(shuffleId, numMaps, dependency)
    }
  }

首先关注一下判断bypass机制的逻辑。

2.2 SortShuffleManager.shouldBypassMergeSort

    def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
    //RDD依赖中已经明确定义map端预聚合
    if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
      false
    } else {
      val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
      //task分区数量小于200,bypassMergeThreshold默认值为200,可以调整
	  dep.partitioner.numPartitions <= bypassMergeThreshold
    }
  }

从源码逻辑可以发现bypass机制需要符合一些条件,shuffle算子不能支持map端聚合且stage中task数量不高于200的阈值。这个200的阈值可以通过spark.shuffle.sort.bypassMergeThreshold动态配置,在生产环境中如果运行环境资源允许,可以调大该阈值的配置,以便命中bypass的可能性,从而提升task运行效率。

2.3 SortShuffleWriter.getWriter

//不同handle创建不同shuffleWriter
    override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Int,
      context: TaskContext): ShuffleWriter[K, V] = {
    numMapsForShuffle.putIfAbsent(
      handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
    val env = SparkEnv.get
    handle match {
      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
        new UnsafeShuffleWriter(
          env.blockManager,
          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
          context.taskMemoryManager(),
          unsafeShuffleHandle,
          mapId,
          context,
          env.conf)
      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
        new BypassMergeSortShuffleWriter(
          env.blockManager,
          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
          bypassMergeSortHandle,
          mapId,
          context,
          env.conf)
      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
        new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
    }
  }

根据不同的ShuffleHandle创建shuffleWrite。

三、Shuffle读

我们从前边的stage划分可以知道,stage划分的分界线是款依赖算子,一个stage的task数量取决于款依赖的分区数量,也就是说临界算子有多少个分区就会划分多少个task,其实每个task就是这一stage从前到后的算子封装,每个task封装的逻辑一样,只是读取处理的数据不同而已,所以,shuffle读就是正常算子的读,没什么特别,只是读取的是上一个stage的shuffle数据而已。
我们从writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])入手,最终追踪到rdd.compute方法,这个compute方法有多种实现,但是只有ShuffledRDD这一类的RDD才会有shuffle的读,所以就从ShuffledRDD.compute方法看起。

3.1 ShuffledRDD.compute

  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
		//注意这里特别关注一下getReader的参数
		//dep.shuffleHandle用于获取前一阶段shuffle数据文件元数据
	    //split.index  上一阶段shuffle结果数据切片索引开始
		//split.index + 1  上一阶段shuffle结果数据切片索引结束
		//从这里可以看出一个reader只读取上一阶段shuffle结果数据的一个分区,这个shuffle溢写数据文件逻辑保持一致
		//即上一阶段shufflewrite会将shuffle结果数据根据下一阶段分区数也是task数量进行落盘,一个task一份分区"一段数据"
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

3.2 BlockStoreShuffleReader.read

  override def read(): Iterator[Product2[K, C]] = {
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
	 //注意这里第一个参数是从handle中获取的shuffleId,即上一个阶段的shuffle结果数据的标记
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // 设置每次拉取的数据大小,默认48M
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
	  //设置每次拉取的数据量最大值,默认为Int最大值
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

    // 根据配置的压缩和解码方式包装流
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      serializerManager.wrapStream(blockId, inputStream)
    }
    val serializerInstance = dep.serializer.newInstance()

    // 为读取到的每个流创建KV迭代器
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    //读取数据后,更新读取量,后续用于记录跟踪和评估
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      //通过map操作每条记录,其实什么都没做,只是记录了读取的数据量
	recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // 设置可中断迭代器,以便取消task
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        //需要对读取的数据在map端进行聚合,比如reduceByKey会在map进行预聚合
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
		//按照key对数据集进行聚合
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        //不关注预聚合,只关注value,将相同key的values进行归集,比如groupByKey
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
	  //未指定聚合器的映map端合并,直接什么都不做,数据保持原样,比如sortByKey对value没任何要求
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // 如果指定排序器,则对数据集进行排序输出
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // 创建排序器ExternalSorter对结果集进行排序
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
          //如果需要排序,则需要使用排序器对数据集进行排序
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
		//不需要排序,数据保持原样
        aggregatedIter
    }
  }

3.3 ExternalSorter.insertAll

参考4.1.1以及后续方法。

四、Shuffle写之SortShuffleWriter

SortShuffleWriter的写方法主要做了三件事情
1.对所有记录写入内存或者文件;
2.对数据溢写并合不同分区数据文件;
3.对数据文件创建索引文件。

override def write(records: Iterator[Product2[K, V]]): Unit = {
	  //注意这个sorter决定了后边溢写的方式
    sorter = if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
		  //map端预聚合需要对key进行排序 dep.keyOrdering
      new ExternalSorter[K, V, C](
        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else {
      //无需map端预聚合,不需要对key进行排序 ordering = None
      new ExternalSorter[K, V, V](
        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    }
	 //对所有记录进行排序
    sorter.insertAll(records)
    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
	//创建临时文件
    val tmp = Utils.tempFileWith(output)
    try {
      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
      }
    }
  }

根据write代码逻辑,首先会根据shuffle中的算子是否需要map端预聚合,生成不同参数的排序器,这个排序器决定了后续数据的排序和数据文件的溢写。接下来就使用该排序器对所有数据记录进行归集操作,要么写入内存,如果内存不足以存储,则溢写到磁盘,注意这里的溢写,只要内存满就会写到文件,所以会产生很多的小文件,这一点在spill方法中会有所体现。

4.1 对所有记录写入内存或者文件

4.1.1 ExternalSorter.insertAll
def insertAll(records: Iterator[Product2[K, V]]): Unit = {
    val shouldCombine = aggregator.isDefined
    //是否需要预聚合
    if (shouldCombine) {
      //map端预聚合使用AppendOnlyMap数据结构
      val mergeValue = aggregator.get.mergeValue
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      //定义聚合函数,通过map数据结构对相同key数据进行聚合
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      //从这里可以看到,其实spark处理数据最终也是对数据逐条进行处理,没什么特殊之处
      while (records.hasNext) {
		  //记录当前处理的记录数,该方法会进行+1操作
        addElementsRead()
        kv = records.next()
			//由于map端预聚合,需要根据key值,对map中数据进行聚合更新
        map.changeValue((getPartition(kv._1), kv._1), update)
        //判断是否需要溢写操作,usingMap = true,因为只有map结构才能很好低支持预聚合
        maybeSpillCollection(usingMap = true)
      }
    } else {
    //如果不需要预聚合,直接插入即可
      while (records.hasNext) {
		  //记录当前处理的记录数
        addElementsRead()
        val kv = records.next()
		//直接插入buffer内存
        buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
        //判断是否需要溢写操作,usingMap = false,这里直接使用缓冲区buffer
        maybeSpillCollection(usingMap = false)
      }
    }
  }

这里一定要注意,不管是哪种方式都会有是否需要溢写的判断,并且判断是在while循环中,也就是说,每条记录进来都会走一遍后续的流程。

4.1.2 ExternalSorter.maybeSpillCollection
  private def maybeSpillCollection(usingMap: Boolean): Unit = {
    var estimatedSize = 0L
    if (usingMap) {
      estimatedSize = map.estimateSize()
	  //判断是否需要spill即溢写文件,注意这里是每来一条记录都会进行一次判断
      if (maybeSpill(map, estimatedSize)) {
        //如果溢写成功map结构初始化
        map = new PartitionedAppendOnlyMap[K, C]
      }
    } else {
      estimatedSize = buffer.estimateSize()
	  //判断是否需要spill即溢写文件,注意这里是每来一条记录都会进行一次判断
      if (maybeSpill(buffer, estimatedSize)) {
       //如果溢写成功buffer缓冲区初始化
        buffer = new PartitionedPairBuffer[K, C]
      }
    }

    if (estimatedSize > _peakMemoryUsedBytes) {
      _peakMemoryUsedBytes = estimatedSize
    }
  }

不管是使用map还是使用buffer封装数据集,最终都要判断数据集是否达到溢写阈值,注意这里一直在内存中,并没有文件什么事情,且对数据集中每条数据都会进行一次判断,调用这个方法是在一个while循环中,可以参考3.1.1。

4.1.3 Spillable
@volatile private[this] var myMemoryThreshold = initialMemoryThreshold
//默认内存缓冲区大小为5M,可以通过spark.shuffle.spill.initialMemoryThreshold动态配置
private[this] val initialMemoryThreshold: Long =
    SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
//默认强制溢写数量为Long最大值,默认值很大,没有优化的必要    
private[this] val numElementsForceSpillThreshold: Long =
    SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue)
    
......

  protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    var shouldSpill = false
	//写入数量是否是32的倍数,且当前内存是仍然允许写入
    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
      //从shuffle内存池中获取最多两倍的当前内存 
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = acquireMemory(amountToRequest)
      myMemoryThreshold += granted
      //进来一条数据后,如果内存数量大于内存阈值则溢写文件
      shouldSpill = currentMemory >= myMemoryThreshold
    }
	 //这里是一个强制性判断,如果写入数量大于强制溢出阈值
    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
    if (shouldSpill) {
      _spillCount += 1
      logSpillage(currentMemory)
      spill(collection)
	  //溢写之后数量清零
      _elementsRead = 0
      _memoryBytesSpilled += currentMemory
      //释放内存
      releaseMemory()
    }
    shouldSpill
  }

每进来一条数据都要判断一下是否符合溢写的条件,如果需要溢写,就会进行溢写操作。

4.1.4 ExternalSorter.spill
//spills是一个数据集:数组,存放的是溢写临时文件
private val spills = new ArrayBuffer[SpilledFile]
.....
override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
    val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
    val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
    //注意溢写一次就会写入一个文件
    spills += spillFile
  }
......
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
      : SpilledFile = {
//  因为这些文件可能在shuffle过程中被读取,所以它们的压缩必须使用spark.shuffle.compress压缩方式,
//  而不是shuffle溢写的压缩方式,因此我们需要在这里使用createTempShuffleBlock;
    val (blockId, file) = diskBlockManager.createTempShuffleBlock()
    // 溢写文件之后需要更新变量
    var objectsWritten: Long = 0
    val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
    val writer: DiskBlockObjectWriter =
      blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)

    //批次大小
    val batchSizes = new ArrayBuffer[Long]

    //每个分区拥有元素数量列表
    val elementsPerPartition = new Array[Long](numPartitions)

    //刷写之后,更新相关变量
    //这里只是定义了这个方法,被调用才会真正执行
    def flush(): Unit = {
      val segment = writer.commitAndGet()
      batchSizes += segment.length
      _diskBytesSpilled += segment.length
      objectsWritten = 0
    }

    var success = false
    try {
      while (inMemoryIterator.hasNext) {
        val partitionId = inMemoryIterator.nextPartition()
        require(partitionId >= 0 && partitionId < numPartitions,
          s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
        inMemoryIterator.writeNext(writer)
        elementsPerPartition(partitionId) += 1
        objectsWritten += 1

        //写入累计数量已经达到批次阈值,则进行刷写,这个阈值是一个调优的对象
        if (objectsWritten == serializerBatchSize) {
          flush()
        }
      }
      if (objectsWritten > 0) {
        flush()
      } else {
        writer.revertPartialWritesAndClose()
      }
      success = true
    } finally {
      if (success) {
		  //关闭流资源
        writer.close()
      } else {     
        writer.revertPartialWritesAndClose()
        if (file.exists()) {
          if (!file.delete()) {
            logWarning(s"Error deleting ${file}")
          }
        }
      }
    }

    SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
  }

4.2 对数据溢写并合不同分区数据文件

4.2.1 ExternalSorter.writePartitionedFile
def writePartitionedFile(
      blockId: BlockId,
      outputFile: File): Array[Long] = {

    //跟踪输出文件中每个范围的位置
    val lengths = new Array[Long](numPartitions)
    val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
      context.taskMetrics().shuffleWriteMetrics)

    if (spills.isEmpty) {
      //没有溢写,说明只需要处理内存数据
      val collection = if (aggregator.isDefined) map else buffer
      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
      while (it.hasNext) {
        val partitionId = it.nextPartition()
        while (it.hasNext && it.nextPartition() == partitionId) {
          it.writeNext(writer)
        }
		 //将内存中数据写到文件,注意这里不是溢写
        val segment = writer.commitAndGet()
        lengths(partitionId) = segment.length
      }
    } else {
      //按分区获取迭代器并直接写入到数据文件,在写入时使用同一个writer,说明写入了一个文件
	  //从这里我们可以得出结论:一个task只写了一个数据文件
      for ((id, elements) <- this.partitionedIterator) {
        if (elements.hasNext) {
          for (elem <- elements) {
            writer.write(elem._1, elem._2)
          }
          val segment = writer.commitAndGet()
          lengths(id) = segment.length
        }
      }
    }

    writer.close()
    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)

    lengths
  }
4.2.2 groupByPartition
 //从处理逻辑来看这里返回的数据结构的key根分区号相关
  private def groupByPartition(data: Iterator[((Int, K), C)])
      : Iterator[(Int, Iterator[Product2[K, C]])] =
  {
    val buffered = data.buffered
    (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
  }
4.2.3 merge
 // 合并已排序的文件,最终写一个新的文件或返回数据,这里的merge的是溢写的临时数据文件
 private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
      : Iterator[(Int, Iterator[Product2[K, C]])] = {
    val readers = spills.map(new SpillReader(_))
    val inMemBuffered = inMemory.buffered
    //这里一定要注意,这里遍历的是分区数量,返回结果key就是分区索引位
    (0 until numPartitions).iterator.map { p =>
      val inMemIterator = new IteratorForPartition(p, inMemBuffered)
	  //这里读取文件并合并分区内数据,将数据归集起来,并没有写入
      val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
      if (aggregator.isDefined) {
        //跨分区执行部分聚合,主要用于处理给定的比较器。默认情况下,同一个key会放到同一个分区中,那么
		//如果使用自定义的的排序器,不同的key可能相等,那就需要跨分区操作。
        (p, mergeWithAggregation(
          iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
      } else if (ordering.isDefined) {
        //不定义排序规则但是需要进行排序,由于同一个分区会有多个临时数据文件,所以需要在多个文件之间进行合并并排序
        (p, mergeSort(iterators, ordering.get))
      } else {
		//如果排序聚合什么都没有,则将数据压平
        (p, iterators.iterator.flatten)
      }
    }
  }

4.3 对数据文件创建索引文件

4.3.1 IndexShuffleBlockResolver.writeIndexFileAndCommit
 //临时文件合并后,些数据文件的索引文件
  def writeIndexFileAndCommit(
      shuffleId: Int,
      mapId: Int,
      lengths: Array[Long],
      dataTmp: File): Unit = {
    val indexFile = getIndexFile(shuffleId, mapId)
    val indexTmp = Utils.tempFileWith(indexFile)
    try {

      //维护索引临时文件
      val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
      Utils.tryWithSafeFinally {
        //取每个数据块的长度,需要将其转换为偏移量然后维护索引临时文件
        var offset = 0L
        out.writeLong(offset)
		//注意这里的lengths是分区的个数,也就是下一个阶段的task个数,从这里可以明确一个task只有一个索引文件
        for (length <- lengths) {
          offset += length
          out.writeLong(offset)
        }
      } {
        out.close()
      }

      val dataFile = getDataFile(shuffleId, mapId)
      //每个执行器只有一个IndexShuffleBlockResolver,此同步确保以下检查和重命名是原子的
      synchronized {
		  //判断索引文件是否已经存在
        val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
        if (existingLengths != null) {
          //存在则说明同一任务的另一次尝试已经成功维护了索引文件,则需要将临时索引文件和临时
		  //数据文件删除即可
          System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
          if (dataTmp != null && dataTmp.exists()) {
            dataTmp.delete()
          }
          indexTmp.delete()
        } else {
          //这是为该task的第一次成功尝试,直接使用现有的索引和数据文件
		  //索引文件存在,则删除
          if (indexFile.exists()) {
            indexFile.delete()
          }
		  //数据文件则删除
          if (dataFile.exists()) {
            dataFile.delete()
          }
		  //将索引临时文件重命名为索引文件
          if (!indexTmp.renameTo(indexFile)) {
            throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
          }
			//将数据文件删除,将现有的数据临时文件重命名为数据文件
          if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
            throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
          }
        }
      }
    } finally {
      if (indexTmp.exists() && !indexTmp.delete()) {
        logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}")
      }
    }
  }

五、Shuffle写之UnsafeShuffleWriter

未完待续。

六、Shuffle写之BypassMergeSortShuffleWriter

未完待续。
由于对spark理解有限,中间难免会有错误,还请各位指正,共同讨论学习。后续随着对spark理解的深入,会继续修改文章。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值