Spark源码2.4.2之Shuffle写过程

ShuffleManager介绍

spark之所以比mapReduce的性能高其中一个主要的原因就是对shuffle过程的优化,一方面spark的shuffle过程更好地利用内存(执行内存),另一方面对于shuffle过程中溢写的磁盘文件归并排序和引入索引文件。当然,spark性能高的另一个主要原因还有对计算链的优化,把多步map类型的计算chain在一起,大大减少中间过程的落盘,这也是spark显著区别于mr的地方。Spark新版本的Shuffle管理器默认是SortShuffleManager。以下源码分析中的ShuffleManager也是SortShuffleManager。
SparkEnv初始化部分的代码:

  val shortShuffleMgrNames = Map(
  "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName,
  "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName)

在下面的源码分析中涉及到的类之间的继承或实现关系如下图:
在这里插入图片描述

Shuffle写过程源码分析

shuffle的过程,无非就是两个步骤,写和读。写是在map阶段,将数据按照一定的分区规则归类到不同的分区中,读是在reduce阶段,每个分区从map阶段的输出中拉取属于自己的数据,所以我们分析Shuffle过程的源码基本也可以沿着这个思路。我们先来分析写的过程,因为对于一个完整的shuffle过程,肯定是先写然后才读的。写是在ShuffleMapTask的runTask中进行的。先来看ShuffleMapTask的**runTask()**方法源码:

// ShuffleMapTask.scala
override def runTask(context: TaskContext): MapStatus = {
  // Deserialize the RDD using the broadcast variable.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  val ser = SparkEnv.get.closureSerializer.newInstance()
  // 反序列化RDD和shuffleDependency,关键的步骤
  // rdd: 该task所在Stage的最后一个RDD
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( // 1
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L

  var writer: ShuffleWriter[Any, Any] = null
  try {
    // 获得shuffle管理器
    val manager = SparkEnv.get.shuffleManager
    // 获取一个shuffle写入器
    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) // 1
    // 这里可以看到rdd计算的核心方法就是iterator方法
    // SortShuffleWriter的write方法可以分为几个步骤:
    // 将上游rdd计算出的数据(通过调用rdd.iterator方法)写入内存缓冲区,
    // 在写的过程中如果超过 内存阈值就会溢写磁盘文件,可能会写多个文件
    // 最后将溢写的文件和内存中剩余的数据一起进行归并排序后写入到磁盘中形成一个大的数据文件
    // 这个排序是先按分区排序,再按key排序
    // 在最后归并排序后写的过程中,没写一个分区就会手动刷写一遍,并记录下这个分区数据在文件中的位移
    // 所以实际上最后写完一个task的数据后,磁盘上会有两个文件:
    // 数据文件和记录每个reduce端partition数据位移的索引文件
    // 2
    writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])  
    // 主要是删除中间过程的溢写文件,向内存管理器释放申请的内存
    writer.stop(success = true).get
  } catch {
    case e: Exception =>
      try {
        if (writer != null) {
          writer.stop(success = false)
        }
      } catch {
        case e: Exception =>
          log.debug("Could not stop writer", e)
      }
      throw e
  }
}

查看ShuffleManager如何获得ShuffleWriter对象。

  • getWriter()
// SortShuffleManager.scala
override def getWriter[K, V](
    handle: ShuffleHandle,
    mapId: Int,
    context: TaskContext): ShuffleWriter[K, V] = {
  numMapsForShuffle.putIfAbsent(
    handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
  val env = SparkEnv.get
  handle match {
    case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
      new UnsafeShuffleWriter(
        env.blockManager,
        shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
        context.taskMemoryManager(),
        unsafeShuffleHandle,
        mapId,
        context,
        env.conf)
    case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
      new BypassMergeSortShuffleWriter(
        env.blockManager,
        shuffleBlockResolver.asInstancOf[IndexShuffleBlockResolver],
        bypassMergeSortHandle,
        mapId,
        context,
        env.conf)
    case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
      new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
  }
}
  • 关于ShuffleHandle对象的类型,该对象是ShuffleDependency的一个属性。该属性值为
// Dependency.scala/ShuffleDependency
val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
  shuffleId, _rdd.partitions.length, this)

在ShuffleManager是SortShuffleManager的情况下其registerShuffle()方法如下:

// SortShuffleManager.scala
override def registerShuffle[K, V, C](
    shuffleId: Int,
    numMaps: Int,
    dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
  if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
    // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
    // need map-side aggregation, then write numPartitions files directly and just concatenate
    // them at the end. This avoids doing serialization and deserialization twice to merge
    // together the spilled files, which would happen with the normal code path. The downside is
    // having multiple files open at a time and thus more memory allocated to buffers.
    // 在不需要map端聚合、partition数量小于200的情况返回BypassMergeSortShuffleHandle对象
    new BypassMergeSortShuffleHandle[K, V](
      shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
      // 在不需要map端聚合、partition数量小于16777216,Serializer支持relocation的情况下
      // 使用SerializedShuffleHandle
    // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
    new SerializedShuffleHandle[K, V](
      shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else {
      // 其他情况
    // Otherwise, buffer map outputs in a deserialized form:
    new BaseShuffleHandle(shuffleId, numMaps, dependency)
  }
}

假设这一步得到的是BaseShuffleHandle对象。后面的源码分析基于SortShuffleWriter。

得到SortShuffleWriter对象之后再来看SortShuffleWriter的write()方法。该方法要写入的数据是由rdd.iterator(partition, context)得到的。所以先分析数据是如何生成的。
以以下代码为例进行分析。reduceBykey是一个shuffle操作。

lines.flatmap(_.split(" "))).map((_, 1)).reduceBykey(_ + _).count

首先从iterator()分析:

// RDD.scala
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) { // 缓存
    getOrCompute(split, context)
  } else { // 没有缓存
    computeOrReadCheckpoint(split, context)
  }
}

如何我们没有对该数据进行缓存,则会调用getOrCompute(),否则调用computeOrReadCheckpoint()。在该示例中没有缓存:

private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
  if (isCheckpointedAndMaterialized) { // 被checkPoint
    firstParent[T].iterator(split, context)
  } else {
    compute(split, context) // 计算
  }
}

假设没有被checkpoint。由示例代码可知ShuffleMapStage的最后一个RDD是MapPartitionsRDD。则实际上是调用MapPartitionsRDD.compute()方法:

// MapPartitionsRDD.scala
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false,
    isFromBarrier: Boolean = false,
    isOrderSensitive: Boolean = false)
  extends RDD[U](prev) {
      // 注意f的第三个参数,也就是说此时会调用父RDD的iterator()函数得到父RDD的对于第index个partition的迭代器
      override def compute(split: Partition, context: TaskContext): Iterator[U] =
          f(context, split.index, firstParent[T].iterator(split, context)) // 1
          
  }

MapPartitionsRDD在RDD的map方法被调用的时候创建。以下代码用来说明f函数指针的取值。

// RDD.scala
def map[U: ClassTag](f: T => U): RDD[U] = withScope {
  val cleanF = sc.clean(f)
  // iter是父RDD的对于第pid个Partition的迭代器
  // iter.map(cleanF)表示对父RDD的数据执行map装换,在该示例中就是转换成(_, 1)
  // 将父RDD迭代器中的每个元素x转换成(x, 1)
  new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
}

知道要写入的数据是如何生成的之后继续往下分析。
现在来看write()方法
总结一下这个方法的主要逻辑:

  • 获取一个排序器,根据是否需要map端聚合传递不同的参数
  • 将数据插入排序器中,这个过程或溢写出多个磁盘文件
  • 根据shuffleid和分区id获取一个磁盘文件名
  • 将多个溢写的磁盘文件和内存中的排序数据进行归并排序,并写到一个文件中,同时返回每个reduce端分区的数据在这个文件中的位移
  • 将索引写入一个索引文件,并将数据文件的文件名由临时文件名改成正式的文件名。
  • 最后封装一个MapStatus对象,用于ShuffleMapTask.runTask的返回值。
  • 在stop方法中还会做一些收尾工作,统计磁盘io耗时,删除中间溢写文件
// SortShuffleWriter.scala
override def write(records: Iterator[Product2[K, V]]): Unit = {
  sorter = if (dep.mapSideCombine) {
    // map端进行合并的情况,此时用户应该提供聚合器和顺序
    new ExternalSorter[K, V, C](
      context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
  } else {
    // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
    // care whether the keys get sorted in each partition; that will be done on the reduce side
    // if the operation being run is sortByKey.
    new ExternalSorter[K, V, V](
      context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
  }
  // 将map数据全部写入排序器中,
  // 这个过程中可能会生成多个溢写文件
  sorter.insertAll(records) // 1

  // Don't bother including the time to open the merged output file in the shuffle write time,
  // because it just opens a single file, so is typically too fast to measure accurately
  // (see SPARK-3570).
  // mapId就是shuffleMap端RDD的partitionId
  // 获取这个map分区的shuffle输出文件名
  val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
  // 为输出文件名加一个uuid后缀
  val tmp = Utils.tempFileWith(output)
  try {
    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
    // 这一步将溢写到的磁盘的文件和内存中的数据进行归并排序,
    // 并溢写到一个文件中,这一步写的文件是临时文件名
    val partitionLengths = sorter.writePartitionedFile(blockId, tmp) // 2
    // 这一步主要是写入索引文件,使用File.renameTo方法将临时索引和临时数据文件重命名为正常的文件名
    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) //3
    // 返回一个状态对象,包含shuffle服务Id和各个分区数据在文件中的位移
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  } finally {
    if (tmp.exists() && !tmp.delete()) {
      logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
    }
  }
}

写数据主要分为三个主要步骤:

  • 利用ExternalSorter.insertAll()方法将数据写到对应的数据结构。需要map端聚合的则写入PartitionedAppendOnlyMap,否则写入PartitionedPairBuffer。当缓存区满的时候会溢写到磁盘,所以这个过程可能会产生多个磁盘文件。
  • 利用ExternalSorter.writePartitionedFile()方法合并前面一个过程产生的多个磁盘文件以及缓存数据,产生一个临时数据文件。
  • 利用IndexShuffleBlockResolver.writeIndexFileAndCommit()方法写索引文件,并对临时文件重命名。

下面来详细分析这三个步骤。

1 首先分析ExternalSorter.insertAll方法。这个过程会生成多个溢写文件。

  • 首先根据是否在map端合并分为两种情况,这两种情况使用的内存存储结构也不一样,对于在map端合并的情况使用的是PartitionedAppendOnlyMap结构,不在map合并则使用PartitionedPairBuffer。其中,PartitionedAppendOnlyMap是用数组和线性探测法实现的map结构。
  • 然后将数据一条一条地循环插入内存的存储结构中,同时考虑到map端合并的情况。
// ExternalSorter.scala
@volatile private var map = new PartitionedAppendOnlyMap[K, C]
@volatile private var buffer = new PartitionedPairBuffer[K, C]
def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  // TODO: stop combining if we find that the reduction factor isn't high
  val shouldCombine = aggregator.isDefined
  // 在map端进行合并的情况
  if (shouldCombine) {
    // Combine values in-memory first using our AppendOnlyMap
    val mergeValue = aggregator.get.mergeValue
    val createCombiner = aggregator.get.createCombiner
    var kv: Product2[K, V] = null
    val update = (hadValue: Boolean, oldValue: C) => {
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    }
    
    while (records.hasNext) {
      addElementsRead() // _elementsRead属性值 + 1
      kv = records.next()
      // 向内存缓冲中插入一条数据
      // map的键为(partitionId, key)
      map.changeValue((getPartition(kv._1), kv._1), update) // 1
      // 如果缓冲超过阈值,就会溢写到磁盘生成一个文件
      // 每写入一条数据就检查一遍内存
      maybeSpillCollection(usingMap = true) // 2
    }
  } else {
    // Stick values into our buffer
    while (records.hasNext) {
      addElementsRead()
      val kv = records.next()
      buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
      maybeSpillCollection(usingMap = false)
    }
  }
}

在map端聚合的情况比不聚合的情况数据的插入要复杂写,涉及数据的更新,所以以下源码分析map端聚合的情况。不聚合的情况只需要将数据插入buffer数组即可同时考虑数组是否要进行扩容。

1.1 在map端聚合的情况下考虑数据的更新
在map端聚合的时候涉及changeValue()方法,其实是调用PartitionedAppendOnlyMap的父类SizeTrackingAppendOnlyMap的方法。然后再调用SizeTrackingAppendOnlyMap的父类AppendOnlyMap的changeValue()方法。

// SizeTrackingAppendOnlyMap.scala
override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
  val newValue = super.changeValue(key, updateFunc)
  super.afterUpdate()
  newValue
}

AppendOnlyMap的changeValue()方法其处理策略为:

  • 首先考虑空值的情况
  • 计算key的hash,然后对容量取余。注意,这里由于容量是2的整数次幂,所以对容量取余的操作等同于和容量-1进行按位与操作
  • 如果不存在旧值,那么直接插入
  • 如果存在旧值,更新旧值
  • 如果发生hash碰撞,那么需要向后探测,并且是跳跃性的探测
// AppendOnlyMap.scala
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
private var data = new Array[AnyRef](2 * capacity)
// 默认容量为64

private def incrementSize() {
  curSize += 1
  // 当当前表的数据量大于阈值则对表进行扩容
  if (curSize > growThreshold) {
    growTable()
  }
}

def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
  assert(!destroyed, destructionMessage)
  val k = key.asInstanceOf[AnyRef]
  if (k.eq(null)) {
      // 如果是第一次插入空值,那么需要将当前表的数据量增加1
    if (!haveNullValue) {
      // incrementSize方法中会判断当前数据量的大小,如果超过阈值就会扩容,这个扩容的方法比较复杂,
      // 就是一个重新hash再分布的过程,不过有一点,
      // 不论是在插入新数据还是重新hash再分布的过程中,
      // 对于hash碰撞的处理策略一定要相同,否则可能造成不一致。  
      incrementSize()
    }
    nullValue = updateFunc(haveNullValue, nullValue)
    haveNullValue = true
    return nullValue
  }
  // mask = capacity - 1
  var pos = rehash(k.hashCode) & mask
  // 线性探测法处理hash碰撞
  // 这里是一个加速的线性探测,即第一次碰撞时走1步,
  // 第二次碰撞时走2步,第三次碰撞时走3步
  var i = 1
  while (true) {
    // data就是hash表,其中的元素为:
    // key0,value0,key1,value1,etc.
    // 偶数位存放key,奇数位存放value
    val curKey = data(2 * pos)
    if (curKey.eq(null)) { // 如果旧值不存在,直接插入
      val newValue = updateFunc(false, null.asInstanceOf[V])
      data(2 * pos) = k
      data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
      incrementSize()
      return newValue
    } else if (k.eq(curKey) || k.equals(curKey)) {
      // key存在的情况 更新旧值  
      val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
      data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
      return newValue
    } else {
      // 表示发生了碰撞,向后探测,跳跃性的探测  
      val delta = i
      pos = (pos + delta) & mask
      i += 1
    }
  }
  null.asInstanceOf[V] // Never reached but needed to keep compiler happy
}

1.2 检查内存是否溢出,并在溢写的过程中对数据进行排序
调用maybeSpillCollection()方法检查缓存是否已满。每插入一条数据就要检查一次内存占用,判断是否需要溢写到磁盘,如果需要就溢写到磁盘。

// ExternalSorter.scala
private def maybeSpillCollection(usingMap: Boolean): Unit = {
  var estimatedSize = 0L
  if (usingMap) {
    // 估算当前插入的数据的内存占用大小
    estimatedSize = map.estimateSize()
    if (maybeSpill(map, estimatedSize)) {
        // 数据已经溢写到磁盘 创建新的AppendOnlyMap数据结构
      map = new PartitionedAppendOnlyMap[K, C]
    }
  } else {
    estimatedSize = buffer.estimateSize()
    if (maybeSpill(buffer, estimatedSize)) {
      buffer = new PartitionedPairBuffer[K, C]
    }
  }

  if (estimatedSize > _peakMemoryUsedBytes) {
    _peakMemoryUsedBytes = estimatedSize
  }
}

以上代码会调用ExternalSorter的父类Spillable.maybeSpill()。

// Spillable.scala
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
  var shouldSpill = false
  // 每写入32条数据检查一次
  // elementsRead是一个函数,其返回值为_elementsRead
  // protected def elementsRead: Int = _elementsRead
  if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
    // Claim up to double our current memory from the shuffle memory pool
    val amountToRequest = 2 * currentMemory - myMemoryThreshold
    // 向内存管理器申请执行内存
    // granted代表内存真正分配的 
    val granted = acquireMemory(amountToRequest)
    myMemoryThreshold += granted
    // If we were granted too little memory to grow further (either tryToAcquire returned 0,
    // or we already had more memory than myMemoryThreshold), spill the current collection
    // 如果内存占用超过了新的阈值,那么就需要溢写
    shouldSpill = currentMemory >= myMemoryThreshold
  }
  // 第二个条件表示当缓存的数据条数达到一定的量就进行溢写
  shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
  // Actually spill
  if (shouldSpill) {
    _spillCount += 1
    logSpillage(currentMemory)
    // 溢写到磁盘
    // 将当前缓存的数据都溢写到磁盘
    spill(collection)
    _elementsRead = 0
    _memoryBytesSpilled += currentMemory
    // 释放内存
    releaseMemory()
  }
  shouldSpill
}

将数据溢写到磁盘的过程调用spill()函数。

// ExternalSorter.scala
override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
  // 获取一个排序后的迭代器
  // 这个方法返回一个按照分区和key排序过的迭代器
  val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
  // 将数据写入磁盘文件中
  val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
  spills += spillFile
}

private def comparator: Option[Comparator[K]] = {
    // 在map端聚合的情况下,会设置聚合器
  if (ordering.isDefined || aggregator.isDefined) {
    Some(keyComparator) // 返回默认的key比较器
  } else {
    None
  }
}

Map端聚合的情况:
WritablePartitionedPairCollection.destructiveSortedWritablePartitionedIterator这个方法返回按照分区和key排序过的迭代器,其具体的排序逻辑在AppendOnlyMap.destructiveSortedIterator中。
具体调用步骤是(借助前面的类图分析):

  • 在该方法中调用PartitionedAppendOnlyMap(WritablePartitionedPairCollection是其父类)的partitionedDestructiveSortedIterator()方法。
// WritablePartitionedPairCollection.scala
def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
  : WritablePartitionedIterator = {
  val it = partitionedDestructiveSortedIterator(keyComparator) //1 
  new WritablePartitionedIterator {
    private[this] var cur = if (it.hasNext) it.next() else null

    def writeNext(writer: DiskBlockObjectWriter): Unit = {
      writer.write(cur._1._2, cur._2)
      cur = if (it.hasNext) it.next() else null
    }

    def hasNext(): Boolean = cur != null

    def nextPartition(): Int = cur._1._1
  }
}
// PartitionedAppendOnlyMap.scala
def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
  : Iterator[((Int, K), V)] = {
  // 将keyComparator封装成一个新的Comparator。该比较器先按partitionId排序,再按照key排序。
  val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
  destructiveSortedIterator(comparator) // 1
}
  • 然后又调用PartitionedAppendOnlyMap的父类AppendOnlyMap的destructiveSortedIterator()
    这段代码分为两块,首先对数组进行压紧,将稀疏的数据(因为之前数据是根据hash值进行存储的)全部转移到数组的头部。然后对数组按照比较器进行排序,比较器首先按照分区进行比较,如果分区相同才按照key进行比较。然后返回一个迭代器,这个迭代器仅仅是对数组的封装。shuffle过程排序的逻辑如下!!!
// AppendOnlyMap.scala
def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
  destroyed = true
  // Pack KV pairs into the front of the underlying array
  var keyIndex, newIndex = 0
  while (keyIndex < capacity) {
    if (data(2 * keyIndex) != null) {
      data(2 * newIndex) = data(2 * keyIndex)
      data(2 * newIndex + 1) = data(2 * keyIndex + 1)
      newIndex += 1
    }
    keyIndex += 1
  }
  assert(curSize == newIndex + (if (haveNullValue) 1 else 0))
  // 根据比较器对数据进行排序
  new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, newIndex, keyComparator) // 1

  new Iterator[(K, V)] {
    var i = 0
    var nullValueReady = haveNullValue
    def hasNext: Boolean = (i < newIndex || nullValueReady)
    def next(): (K, V) = {
      if (nullValueReady) {
        nullValueReady = false
        (null.asInstanceOf[K], nullValue)
      } else {
        val item = (data(2 * i).asInstanceOf[K], data(2 * i + 1).asInstanceOf[V])
        i += 1
        item
      }
    }
  }
}

用到的排序比较器,是在partitionKeyComparator方法中创建的匿名比较器:

// WritablePartitionedPairCollection.scala
def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = {
  new Comparator[(Int, K)] {
      // a的格式为(partitionId, key)
    override def compare(a: (Int, K), b: (Int, K)): Int = {
      val partitionDiff = a._1 - b._1
      if (partitionDiff != 0) {
        partitionDiff
      } else {
        keyComparator.compare(a._2, b._2)
      }
    }
  }
}

再将数据进行排序后就可以溢写到磁盘了。

Map端无聚合的情况:
WritablePartitionedPairCollection.destructiveSortedWritablePartitionedIterator这个方法返回按照分区排序过的迭代器,其具体的排序逻辑在PartitionedPairBuffer.destructiveSortedIterator中。

override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
  : Iterator[((Int, K), V)] = {
  // 按照分区排序
  // keyComparator.map(partitionKeyComparator)为空,所以使用partitionComparator
  val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
  new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator)
  iterator
}

附上述代码涉及到的scala Option的用法:

def map[B](f: (A) => B): Option[B]
如果选项包含有值, 返回由函数 f 处理后的 Some,否则返回 None
def getOrElse[B >: A](default: => B): B
如果选项包含有值,返回选项值,否则返回设定的默认值。

在map端无聚合的情况下用的比较器是partitionComparator。其比较逻辑如下:

def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
  override def compare(a: (Int, K), b: (Int, K)): Int = {
    a._1 - b._1
  }
}

所以下map端聚合的情况下,中间数据文件会按照先分区排序,分区内数据按照key排序。而map端无聚合的情况下,按照分区排序。
总结一下数据通过ExternalSorter向磁盘溢写的全过程:

  • 首先,数据会被一条一条地向内部的map结构中插入。
  • 每插入一条数据都会检查内存占用情况,如果内存占用超过阈值,并且申请不到足够的执行内存,就会将目前内存中的数据溢写到磁盘。
  • 对于溢写的过程:首先会将数据按照分区和key进行排序,相同分区的数据排在一起,然后根据提供的排序器按照key的顺序排;然后通过DiskBlockManager和BlockManager获取DiskBlockWriter将数据写入磁盘形成一个文件。
  • 在整个写入过程中,会溢写多个文件。

2 对溢写的文件和内存数据进行合并
利用ExternalSorter.writePartitionedFile()方法将溢写的多个数据文件和缓存数据进行合并。
总结一下主要的步骤:

  • 仍然是通过blockManager获取一个磁盘写入器
  • 将内部溢写的多个磁盘文件和滞留在内存的数据进行归并排序,并封装成一个按照分区归类的迭代器
  • 循环将数据写入磁盘,每当一个分区的数据写完后,进行一次刷写,将数据从os的文件缓冲区同步到磁盘上,然后获取此时的文件长度,记录下每个分区在文件中的位移
// ExternalSorter.scala
def writePartitionedFile(
    blockId: BlockId,
    outputFile: File): Array[Long] = {

  // Track location of each range in the output file
  val lengths = new Array[Long](numPartitions)
  val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
    context.taskMetrics().shuffleWriteMetrics)
  // 如果前面没有数据溢写到磁盘中,
  // 则只需要将内存中的数据溢写到磁盘
  if (spills.isEmpty) {
    // Case where we only have in-memory data
    val collection = if (aggregator.isDefined) map else buffer
    // 返回排序后的迭代器
    val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
    while (it.hasNext) {
      val partitionId = it.nextPartition()
      while (it.hasNext && it.nextPartition() == partitionId) {
        it.writeNext(writer)
      }
       // 写完一个分区刷写一次
      val segment = writer.commitAndGet()
      // 记录下分区的数据在文件中的位移
      lengths(partitionId) = segment.length
    }
  } else { // 有溢写到磁盘的文件
    // We must perform merge-sort; get an iterator by partition and write everything directly.
    // 封装一个用于归并各个溢写文件以及内存缓冲区数据的迭代器
    // 这个封装的迭代器是实现归并排序的关键
    // id为partitionId,elements为该partition对应的迭代器
    for ((id, elements) <- this.partitionedIterator) { // 1
      if (elements.hasNext) {
        for (elem <- elements) {
          writer.write(elem._1, elem._2)
        }
         // 每写完一个分区,主动刷写一次,获取文件位移,
        // 这个位移就是写入的分区的位移,
        // reduce端在拉取数据时就会根据这个位移直接找到应该拉取的数据的位置
        val segment = writer.commitAndGet()
        lengths(id) = segment.length
      }
    }
  }
  // 写完后更新一些统计信息
  writer.close()
  context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
  context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
  context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
  // 返回每个reduce端分区数据在文件中的位移信息
  lengths
}

返回归并后的迭代器的方法为partitionedIterator,其源码为:

// ExternalSorter.scala
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
  val usingMap = aggregator.isDefined
  val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
  if (spills.isEmpty) {  // 如果前面没有数据溢写到磁盘中
    if (!ordering.isDefined) {
      // The user hasn't requested sorted keys, so only sort by partition ID, not key
      // 只需要根据partitionId排序
      groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
    } else {
      // We do need to sort by both partition ID and key
      // 需要根据partitionId和key排序
      groupByPartition(destructiveIterator(
        collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
    }
  } else {
    // Merge spilled and in-memory data
    // 合并溢写的数据和内存数据
    merge(spills, destructiveIterator(
      collection.partitionedDestructiveSortedIterator(comparator))) // 1
  }
}

对溢写的文件和内存数据进行合并

// ExternalSorter.scala
private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
    : Iterator[(Int, Iterator[Product2[K, C]])] = {
   // spills中的每一项都是溢写文件名     
  val readers = spills.map(new SpillReader(_))
  val inMemBuffered = inMemory.buffered
  (0 until numPartitions).iterator.map { p =>
    //  将内存数据中属于partition p的数据放入 inMemIterator 中
    val inMemIterator = new IteratorForPartition(p, inMemBuffered)
    // 合并溢写文件中所有属于partition p的数据和inMemIterator 的数据
    val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
    if (aggregator.isDefined) {
      // Perform partial aggregation across partitions
      // 聚合再排序
      (p, mergeWithAggregation(
        iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
    } else if (ordering.isDefined) {
      // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
      // sort the elements without trying to merge them
      // 只要排序
      (p, mergeSort(iterators, ordering.get))
    } else {
        // 不用排序
      (p, iterators.iterator.flatten)
    }
  }
}

3 写索引文件
这个方法的作用主要是将每个的分区的位移值写入到一个索引文件中,并且在这个过程会将之前临时的数据文件重命名为正式数据文件。

// IndexShuffleBlockResolver.scala
def writeIndexFileAndCommit(
    shuffleId: Int,
    mapId: Int,
    lengths: Array[Long],
    dataTmp: File): Unit = {
  val indexFile = getIndexFile(shuffleId, mapId)
  // 索引临时文件
  val indexTmp = Utils.tempFileWith(indexFile)
  try {
    val dataFile = getDataFile(shuffleId, mapId)
    // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
    // the following check and rename are atomic.
    synchronized {
      val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
      if (existingLengths != null) {
          // 如果existingLengths存在直接删除临时文件即可 
        // Another attempt for the same task has already written our map outputs successfully,
        // so just use the existing partition lengths and delete our temporary map outputs.
        System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
        if (dataTmp != null && dataTmp.exists()) {
          dataTmp.delete()
        }
      } else {
        // This is the first successful attempt in writing the map outputs for this task,
        // so override any existing index and data files with the ones we wrote.
        // 写索引文件。
        val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
        Utils.tryWithSafeFinally {
          // We take in lengths of each block, need to convert it to offsets.
          var offset = 0L
          out.writeLong(offset)
          for (length <- lengths) {
            offset += length
            out.writeLong(offset)
          }
        } {
          out.close()
        }

        if (indexFile.exists()) {
          indexFile.delete()
        }
        if (dataFile.exists()) {
          dataFile.delete()
        }
        // 对索引文件重命名
        if (!indexTmp.renameTo(indexFile)) {
          throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
        }
        // 对数据文件重命名
        if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
          throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
        }
      }
    }
  } finally {
    if (indexTmp.exists() && !indexTmp.delete()) {
      logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}")
    }
  }
}

到此,Spark Shuffle的写过程就结束了。下一篇分析其读过程。

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值