Spark Source Code Reading (1): ShuffleWriter

Spark's shuffle process is fairly complex, and even after reading the source several times it is hard to keep it all in mind. This post is a rough walkthrough of the shuffle path, without drilling into every detail of the source; I will keep refining it over time.

Take join as an example:

val rdd = rdd1.join(rdd2)

// The join method
def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = self.withScope {
  join(other, defaultPartitioner(self, other))
}

// The defaultPartitioner method, which picks the default Partitioner
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
  val rdds = (Seq(rdd) ++ others)
  val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0))
  if (hasPartitioner.nonEmpty) {
    hasPartitioner.maxBy(_.partitions.length).partitioner.get
  } else {
    if (rdd.context.conf.contains("spark.default.parallelism")) {
      new HashPartitioner(rdd.context.defaultParallelism)
    } else {
      new HashPartitioner(rdds.map(_.partitions.length).max)
    }
  }
}

join calls Partitioner.defaultPartitioner(RDD[_], RDD[_]*) to obtain a partitioner. The logic is simple: if any of the RDDs already has a Partitioner, use the partitioner of the RDD with the largest number of partitions; if none of them has one, check whether spark.default.parallelism is configured, and if so create a HashPartitioner with that many partitions; otherwise create a HashPartitioner whose partition count is the maximum partition count across all the RDDs.
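
To make the rule concrete, here is a small illustration (a sketch only: it assumes an existing SparkContext named sc, and the data is arbitrary):

import org.apache.spark.HashPartitioner

val a = sc.parallelize(1 to 100).map(i => (i, i)).partitionBy(new HashPartitioner(30))
val b = sc.parallelize(1 to 100).map(i => (i, i.toString))   // no partitioner
// Only `a` carries a partitioner, so under the defaultPartitioner code above the join
// reuses a's HashPartitioner and the result has 30 partitions.
val joined = a.join(b)
assert(joined.getNumPartitions == 30)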

Let's continue with the join method:

def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = self.withScope {
  this.cogroup(other, partitioner).flatMapValues( pair =>
    for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, w)
  )
}

// The cogroup method mainly builds a CoGroupedRDD and ultimately returns a MapPartitionsRDD
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner)
    : RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope {
  if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
    throw new SparkException("HashPartitioner cannot partition array keys.")
  }
  val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
  cg.mapValues { case Array(vs, w1s) =>
    (vs.asInstanceOf[Iterable[V]], w1s.asInstanceOf[Iterable[W]])
  }
}

cogroup first validates the key type against the Partitioner type: a HashPartitioner cannot partition array-typed keys.

  1. It then creates a CoGroupedRDD: val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)

Look at getDependencies in CoGroupedRDD to see which kind of Dependency the new CoGroupedRDD has on its parents:

// CoGroupedRDD.getDependencies: how the dependency type of a CoGroupedRDD is determined
override def getDependencies: Seq[Dependency[_]] = {
  rdds.map { rdd: RDD[_] =>
    if (rdd.partitioner == Some(part)) {
      logDebug("Adding one-to-one dependency with " + rdd)
      new OneToOneDependency(rdd)
    } else {
      logDebug("Adding shuffle dependency with " + rdd)
      new ShuffleDependency[K, Any, CoGroupCombiner](
        rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer)
    }
  }
}

As the code shows, if a parent RDD's Partitioner equals the Partitioner passed in when creating the CoGroupedRDD, a OneToOneDependency is created; otherwise (the RDD has no Partitioner, or its Partitioner differs from the one passed in) a ShuffleDependency is created. So if the two RDDs being joined were both partitioned beforehand with the same Partitioner, the join between them will not introduce a shuffle. This is worth keeping in mind when writing RDD-based Spark jobs, as it can buy some performance.
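
A minimal sketch of this pattern, assuming an existing SparkContext sc (the data and partition count are placeholders):

import org.apache.spark.HashPartitioner

val part = new HashPartitioner(8)
// partitionBy itself shuffles once per input, but the results can be cached and reused.
val left  = sc.parallelize(Seq((1, "a"), (2, "b"))).partitionBy(part)
val right = sc.parallelize(Seq((1, 1.0), (2, 2.0))).partitionBy(part)
// Both inputs now carry the same partitioner, so CoGroupedRDD.getDependencies builds
// OneToOneDependency for both sides and the join itself adds no new shuffle.
val joined = left.join(right)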

  2. After creating the CoGroupedRDD, cogroup calls mapValues on it, which wraps it in a MapPartitionsRDD.

def mapValues[U](f: V => U): RDD[(K, U)] = self.withScope {
  val cleanF = self.context.clean(f)
  new MapPartitionsRDD[(K, U), (K, V)](self,
    (context, pid, iter) => iter.map { case (k, v) => (k, cleanF(v)) },
    preservesPartitioning = true)
}

At this point cogroup is done and we have a MapPartitionsRDD. join then calls flatMapValues on it to produce the final RDD; stepping into flatMapValues shows that it creates yet another MapPartitionsRDD, so the end result of the join is a MapPartitionsRDD.

private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false)
  extends RDD[U](prev)

// RDD中的构造方法
def this(@transient oneParent: RDD[_]) =
  this(oneParent.context, List(new OneToOneDependency(oneParent)))

As the RDD constructor shows, a MapPartitionsRDD always has a OneToOneDependency on its parent.

So the overall flow is:

(Figure: the rdd1.join(rdd2) lineage — rdd1 and rdd2 feed the CoGroupedRDD (rdd3), followed by the mapValues MapPartitionsRDD (rdd4) and the flatMapValues MapPartitionsRDD (rdd5).)

The flow above assumes neither rdd1 nor rdd2 has a Partitioner. If rdd1 has more partitions, a HashPartitioner with rdd1's partition count is created, and the downstream RDDs all inherit rdd3's Partitioner. In rdd3's dependencies, the dependencies on both rdd1 and rdd2 are ShuffleDependency. If we collect rdd5, we can therefore expect the job to split into 3 stages.
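
A quick, hedged way to check that stage count (assuming an existing SparkContext sc; partition counts are arbitrary):

val rdd1 = sc.parallelize(Seq((1, "a"), (2, "b")), 4)   // no partitioner
val rdd2 = sc.parallelize(Seq((1, 1.0), (2, 2.0)), 2)   // no partitioner
val joined = rdd1.join(rdd2)
// With spark.default.parallelism unset, the join uses HashPartitioner(4) (the larger input).
println(joined.toDebugString)   // shuffle boundaries show up as indentation changes
joined.collect()                // runs as 3 stages: one ShuffleMapStage per input + the ResultStage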

The shuffle process

The sections above were just groundwork; now to the main topic. Skipping some intermediate steps (there is too much to cover), let's look directly at DAGScheduler.handleJobSubmitted.

// DAGScheduler.handleJobSubmitted creates the stages and submits them
private[scheduler] def handleJobSubmitted(jobId: Int,
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    callSite: CallSite, 
    listener: JobListener,
    properties: Properties) {
  var finalStage: ResultStage = null
  try {
    // New stage creation may throw an exception if, for example, jobs are run on a
    // HadoopRDD whose underlying HDFS files have been deleted.
    finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
  } catch {
    case e: Exception =>
      logWarning("Creating new stage failed due to exception - job: " + jobId, e)
      listener.jobFailed(e)
      return
  }
  // ...... some code omitted
  submitStage(finalStage)
}

It first creates the finalStage by calling DAGScheduler.createResultStage. There are two kinds of Stage: ShuffleMapStage and ResultStage.

private def createResultStage(
    rdd: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    jobId: Int,
    callSite: CallSite): ResultStage = {
  val parents = getOrCreateParentStages(rdd, jobId)
  val id = nextStageId.getAndIncrement()
  val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
  stageIdToStage(id) = stage
  updateJobIdStageIdMaps(jobId, stage)
  stage
}

The final stage is always a ResultStage, and creating it requires its parent stages: starting from the final RDD, getOrCreateParentStages returns them.

private def getOrCreateParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
  getShuffleDependencies(rdd).map { shuffleDep =>
    getOrCreateShuffleMapStage(shuffleDep, firstJobId)
  }.toList
}

Parent stages are carved out along ShuffleDependency boundaries, so the first step is to find all shuffle dependencies of the final RDD; then a ShuffleMapStage is created for each ShuffleDependency as a parent of the finalStage.

private[scheduler] def getShuffleDependencies(
    rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
  val parents = new HashSet[ShuffleDependency[_, _, _]]
  val visited = new HashSet[RDD[_]]
  val waitingForVisit = new Stack[RDD[_]]
  waitingForVisit.push(rdd)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      toVisit.dependencies.foreach {
        case shuffleDep: ShuffleDependency[_, _, _] =>
          parents += shuffleDep
        case dependency =>
          waitingForVisit.push(dependency.rdd)
      }
    }
  }
  parents
}

The search for shuffle dependencies is a stack-based traversal: when a ShuffleDependency is encountered it is added to parents, and the RDD behind any non-shuffle dependency is pushed onto the stack to be visited. This also shows that a chain of RDDs connected only by narrow (non-shuffle) dependencies ends up in the same Stage.

Also note that the traversal does not continue upward past a shuffle dependency, so it can be read as "find the first level of shuffle dependencies upstream".
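
A hedged illustration of what "first level" means (assuming an existing SparkContext sc; the lineage is made up):

val rddA = sc.parallelize(Seq(("a", 1), ("b", 2)))
val rddB = rddA.reduceByKey(_ + _)                  // shuffle 1
val rddC = rddB.map { case (k, v) => (k, v * 2) }   // narrow; map does not preserve the partitioner
val rddD = rddC.reduceByKey(_ + _)                  // shuffle 2
// getShuffleDependencies(rddD) returns only shuffle 2; shuffle 1 is found later by
// getMissingAncestorShuffleDependencies when the ShuffleMapStage for shuffle 2 is created.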

Back in getOrCreateParentStages, a ShuffleMapStage is then created (or reused) for each shuffle dependency found, via getOrCreateShuffleMapStage:

private def getOrCreateShuffleMapStage(
    shuffleDep: ShuffleDependency[_, _, _],
    firstJobId: Int): ShuffleMapStage = {
  shuffleIdToMapStage.get(shuffleDep.shuffleId) match {
    case Some(stage) =>
      stage
    case None =>
      getMissingAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
        if (!shuffleIdToMapStage.contains(dep.shuffleId)) {
          createShuffleMapStage(dep, firstJobId)
        }
      }
      // Finally, create a stage for the given shuffle dependency.
      createShuffleMapStage(shuffleDep, firstJobId)
  }
}

If no stage is registered for this shuffle yet, it first walks further up to find the ancestor shuffle dependencies of the shuffle dependency's RDD, creates a ShuffleMapStage for each missing one, and finally creates the stage for the given shuffle dependency itself.

private def getMissingAncestorShuffleDependencies(
    rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
  val ancestors = new Stack[ShuffleDependency[_, _, _]]
  val visited = new HashSet[RDD[_]]
  // We are manually maintaining a stack here to prevent StackOverflowError
  // caused by recursively visiting
  val waitingForVisit = new Stack[RDD[_]]
  waitingForVisit.push(rdd)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      getShuffleDependencies(toVisit).foreach { shuffleDep =>
        if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) {
          ancestors.push(shuffleDep)
          waitingForVisit.push(shuffleDep.rdd)
        } // Otherwise, the dependency and its ancestors have already been registered.
      }
    }
  }
  ancestors
}

getMissingAncestorShuffleDependencies exhaustively visits shuffle dependencies at every level, so at this point all stages have been created.

Note the role shuffle dependencies play in dividing stages. Back in handleJobSubmitted, the final step is to call submitStage to submit the finalStage.

private def submitStage(stage: Stage) {
  val jobId = activeJobForStage(stage)
  if (jobId.isDefined) {
    logDebug("submitStage(" + stage + ")")
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      val missing = getMissingParentStages(stage).sortBy(_.id)
      logDebug("missing: " + missing)
      if (missing.isEmpty) {
        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
        submitMissingTasks(stage, jobId.get)
      } else {
        for (parent <- missing) {
          submitStage(parent)
        }
        waitingStages += stage
      }
    }
  } else {
    abortStage(stage, "No active job for stage " + stage.id, None)
  }
}

submitStage first looks for parent stages that have not run yet and submits those first; the call is recursive, so although submission starts from the last stage, tasks actually start running from the first stage via submitMissingTasks. Task submission itself will be covered separately. A ShuffleMapTask is created for a ShuffleMapStage and a ResultTask for a ResultStage.

Executing ShuffleMapTask

With the analysis above, the tasks of the different stages have been submitted. For a ShuffleMapTask, runTask looks like this:

override def runTask(context: TaskContext): MapStatus = {
  // Deserialize the RDD using the broadcast variable.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L


  var writer: ShuffleWriter[Any, Any] = null
  try {
    val manager = SparkEnv.get.shuffleManager
    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
    writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
    writer.stop(success = true).get
  } catch {
    case e: Exception =>
      try {
        if (writer != null) {
          writer.stop(success = false)
        }
      } catch {
        case e: Exception =>
          log.debug("Could not stop writer", e)
      }
      throw e
  }
}

It first obtains the ShuffleManager from SparkEnv (Spark now uses SortShuffleManager by default) and then calls ShuffleManager.getWriter.

override def getWriter[K, V](
    handle: ShuffleHandle,
    mapId: Int,
    context: TaskContext): ShuffleWriter[K, V] = {
  numMapsForShuffle.putIfAbsent(
    handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
  val env = SparkEnv.get
  handle match {
    case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
      new UnsafeShuffleWriter(
        env.blockManager,
        shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
        context.taskMemoryManager(),
        unsafeShuffleHandle,
        mapId,
        context,
        env.conf)
    case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
      new BypassMergeSortShuffleWriter(
        env.blockManager,
        shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
        bypassMergeSortHandle,
        mapId,
        context,
        env.conf)
    case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
      new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
  }
}

Which ShuffleWriter gets created depends on the type of ShuffleHandle. There are three kinds of ShuffleHandle:

  • BaseShuffleHandle

  • BypassMergeSortShuffleHandle

  • SerializedShuffleHandle

The ShuffleHandle is obtained from the dependency. In ShuffleDependency, shuffleHandle is defined as:

val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
  shuffleId, _rdd.partitions.length, this)

So the handle comes from ShuffleManager.registerShuffle; let's look at that method:

override def registerShuffle[K, V, C](
    shuffleId: Int,
    numMaps: Int,
    dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
  if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
    new BypassMergeSortShuffleHandle[K, V](
      shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
    new SerializedShuffleHandle[K, V](
      shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else {
    new BaseShuffleHandle(shuffleId, numMaps, dependency)
  }
}

registerShuffle first checks whether the bypass-merge-sort path applies, via SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency). The bypass path is used when:

  1. dependency.mapSideCombine is false, i.e. no map-side combine is required; and

  2. the number of partitions is at most spark.shuffle.sort.bypassMergeThreshold, which defaults to 200.

It then checks whether the serialized shuffle applies, via SortShuffleManager.canUseSerializedShuffle(dependency). The conditions are:

  1. the serializer must support relocation of serialized objects;

  2. there is no aggregator; and

  3. the number of partitions is not greater than MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE.

If neither of the two handles above applies, a BaseShuffleHandle is used.
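
As a summary, here is a hedged paraphrase of the two checks (condensed from the Spark 2.x sources, with the dependency's relevant fields passed in explicitly so the sketch stands alone; it is not the verbatim implementation):

import org.apache.spark.SparkConf

// Paraphrase of SortShuffleWriter.shouldBypassMergeSort
def shouldBypassMergeSort(conf: SparkConf, mapSideCombine: Boolean, numPartitions: Int): Boolean = {
  if (mapSideCombine) {
    false // map-side combine forces the sort-based path
  } else {
    numPartitions <= conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
  }
}

// Paraphrase of SortShuffleManager.canUseSerializedShuffle
def canUseSerializedShuffle(serializerSupportsRelocation: Boolean,
                            hasAggregator: Boolean,
                            numPartitions: Int): Boolean = {
  serializerSupportsRelocation &&   // e.g. KryoSerializer supports relocation of serialized objects
    !hasAggregator &&               // no aggregator defined on the dependency
    numPartitions <= (1 << 24)      // MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = 16777216
}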

ShuffleWriter

BypassMergeSortShuffleWriter

BypassMergeSortShuffleWriter(
    BlockManager blockManager,
    IndexShuffleBlockResolver shuffleBlockResolver,
    BypassMergeSortShuffleHandle<K, V> handle,
    int mapId,
    TaskContext taskContext,
    SparkConf conf)

Its main constructor parameters are:

  • blockManager: BlockManager

  • shuffleBlockResolver: IndexShuffleBlockResolver

  • handle: BypassMergeSortShuffleHandle

  • mapId: Int

  • taskContext: TaskContext

  • conf: SparkConf

The write method is shown below; the records iterator covers one map-side partition, and write is invoked from ShuffleMapTask.runTask:

public void write(Iterator<Product2<K, V>> records) throws IOException {
  assert (partitionWriters == null);
  if (!records.hasNext()) {
    partitionLengths = new long[numPartitions];
    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
    return;
  }
  final SerializerInstance serInstance = serializer.newInstance();
  final long openStartTime = System.nanoTime();
  partitionWriters = new DiskBlockObjectWriter[numPartitions];
  partitionWriterSegments = new FileSegment[numPartitions];
  for (int i = 0; i < numPartitions; i++) {
    final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
      blockManager.diskBlockManager().createTempShuffleBlock();
    final File file = tempShuffleBlockIdPlusFile._2();
    final BlockId blockId = tempShuffleBlockIdPlusFile._1();
    partitionWriters[i] =
      blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
  }
  // Creating the file to write to and creating a disk writer both involve interacting with
  // the disk, and can take a long time in aggregate when we open many files, so should be
  // included in the shuffle write time.
  writeMetrics.incWriteTime(System.nanoTime() - openStartTime);


  while (records.hasNext()) {
    final Product2<K, V> record = records.next();
    final K key = record._1();
    partitionWriters[partitioner.getPartition(key)].write(key, record._2());
  }


  for (int i = 0; i < numPartitions; i++) {
    final DiskBlockObjectWriter writer = partitionWriters[i];
    partitionWriterSegments[i] = writer.commitAndGet();
    writer.close();
  }


  File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
  File tmp = Utils.tempFileWith(output);
  try {
    partitionLengths = writePartitionedFile(tmp);
    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
  } finally {
    if (tmp.exists() && !tmp.delete()) {
      logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
    }
  }
  mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

The BypassMergeSortShuffleWriter flow is roughly:

  • Based on the dependency's partitioner.numPartitions, create one DiskBlockObjectWriter per output partition; each writer gets its own file, named "temp_shuffle_%UUID%".

  • Iterate over the records; for each record, pick the DiskBlockObjectWriter for its key's target partition and write it: partitionWriters[partitioner.getPartition(key)].write(key, record._2()).

  • Flush and commit each DiskBlockObjectWriter (commitAndGet) to obtain its FileSegment, then close each writer.

  • Create the final output file of this shuffle map task, shuffle_%shuffleId%_%mapId%_%reduceId%.data, and run partitionLengths = writePartitionedFile(tmp); this concatenates the per-partition temp files produced earlier into the single final data file.

  • Write the index file via shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); partitionLengths records the byte length of each partition's data, so the offsets in the index file let a reducer locate its partition's slice. The index file is named shuffle_%shuffleId%_%mapId%_%reduceId%.index.

Overall, BypassMergeSortShuffleWriter first writes one data file per target partition (the reduce-side partitions). For example, if the parent RDD has 500 partitions and the target partition count is 150, each map task writes 150 temp files and then merges them into one data file plus one index file, so the final file count is 2 × the number of ShuffleMapTasks. The benefit is far fewer small files; small files are a disaster for big-data processing, and this is also why Spark abandoned hash-based shuffle.

In addition, the bypass mode does not sort the data. In some jobs sorting is unnecessary or very expensive, so if you are sure your job does not need it you can steer the shuffle onto the bypass path by tuning spark.shuffle.sort.bypassMergeThreshold.
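
For example (a hedged configuration sketch; the value 300 is only illustrative, and the bypass path still requires that there be no map-side combine):

import org.apache.spark.SparkConf

// A job whose widest shuffle has 300 reduce partitions and no map-side aggregation
// will take the bypass path with this setting.
val conf = new SparkConf()
  .set("spark.shuffle.sort.bypassMergeThreshold", "300")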

 

SortShuffleWriter

The core of SortShuffleWriter is using an ExternalSorter to sort the data and spill it to disk.

override def write(records: Iterator[Product2[K, V]]): Unit = {
  sorter = if (dep.mapSideCombine) {
    require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
    new ExternalSorter[K, V, C](
      context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
  } else {
    new ExternalSorter[K, V, V](
      context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
  }
  sorter.insertAll(records)

  val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
  val tmp = Utils.tempFileWith(output)
  try {
    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
    val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  } finally {
    if (tmp.exists() && !tmp.delete()) {
      logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
    }
  }
}

The important ExternalSorter members include:

  • context: TaskContext

  • aggregator: Option[Aggregator[K,V,C]]

  • partitioner: Option[Partitioner]

  • ordering: Option[Ordering[K]]

  • serializer: Serializer

  • map: PartitionedAppendOnlyMap, the in-memory buffer of (aggregated) records when an aggregator is present.

  • buffer: PartitionedPairBuffer, the in-memory buffer of records when no aggregator is present.

PartitionedAppendOnlyMap and PartitionedPairBuffer will be covered separately later.

The SortShuffleWriter write path is roughly:

  • Create an ExternalSorter

  • Call ExternalSorter.insertAll to process every record

  • Merge the spilled data files and generate the index file

Note that SortShuffleWriter does involve sorting. Let's start with ExternalSorter.insertAll:

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  val shouldCombine = aggregator.isDefined


  if (shouldCombine) {
    // Combine values in-memory first using our AppendOnlyMap
    val mergeValue = aggregator.get.mergeValue
    val createCombiner = aggregator.get.createCombiner
    var kv: Product2[K, V] = null
    val update = (hadValue: Boolean, oldValue: C) => {
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    }
    while (records.hasNext) {
      addElementsRead()
      kv = records.next()
      map.changeValue((getPartition(kv._1), kv._1), update)
      maybeSpillCollection(usingMap = true)
    }
  } else {
    // Stick values into our buffer
    while (records.hasNext) {
      addElementsRead()
      val kv = records.next()
      buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
      maybeSpillCollection(usingMap = false)
    }
  }
}

Each record is processed and stored into either map or buffer, after which maybeSpillCollection(usingMap) is called to decide whether a spill is needed.

private def maybeSpillCollection(usingMap: Boolean): Unit = {
  var estimatedSize = 0L
  if (usingMap) {
    estimatedSize = map.estimateSize()
    if (maybeSpill(map, estimatedSize)) {
      map = new PartitionedAppendOnlyMap[K, C]
    }
  } else {
    estimatedSize = buffer.estimateSize()
    if (maybeSpill(buffer, estimatedSize)) {
      buffer = new PartitionedPairBuffer[K, C]
    }
  }


  if (estimatedSize > _peakMemoryUsedBytes) {
    _peakMemoryUsedBytes = estimatedSize
  }
}

maybeSpill then decides, based on the current memory usage and the number of records buffered in memory, whether to spill; if so, it performs the spill.

protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
  var shouldSpill = false
  if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
    val amountToRequest = 2 * currentMemory - myMemoryThreshold
    val granted = acquireMemory(amountToRequest)
    myMemoryThreshold += granted
    shouldSpill = currentMemory >= myMemoryThreshold
  }
  shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
  // Actually spill
  if (shouldSpill) {
    _spillCount += 1
    logSpillage(currentMemory)
    spill(collection)
    _elementsRead = 0
    _memoryBytesSpilled += currentMemory
    releaseMemory()
  }
  shouldSpill
}

  • First it checks whether the current memory usage has reached the memory threshold. The initial threshold comes from spark.shuffle.spill.initialMemoryThreshold and defaults to 5 MB, but exceeding 5 MB does not immediately produce a spill file. Instead the sorter tries to acquire 2 * currentMemory - myMemoryThreshold additional memory via acquireMemory (it may get less than requested), adds whatever was granted to the current threshold, and compares the new threshold with the current memory usage: if the usage is still at least the new threshold, a spill is needed. For example, with the initial 5 MB threshold and 17 MB currently used, the sorter requests 2 * 17 - 5 = 29 MB; if the full 29 MB is granted the new threshold becomes 29 + 5 = 34 MB and no spill happens, whereas if only 10 MB is granted the new threshold is 10 + 5 = 15 MB and a spill is required (see the sketch after this list).

  • It also checks the number of records processed so far against spark.shuffle.spill.numElementsForceSpillThreshold, which defaults to Long.MaxValue; once the buffered record count reaches this threshold a spill is forced.

  • If a spill is needed, the spill method is executed.
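
Here is a minimal, self-contained sketch of the threshold arithmetic described above (acquireMemory stands in for the memory-manager call; the byte counts mirror the 5 MB / 17 MB example and are otherwise arbitrary):

object SpillThresholdSketch {
  // Returns (shouldSpill, newThreshold), mimicking the core check in Spillable.maybeSpill.
  def shouldSpill(currentMemory: Long, threshold: Long, acquireMemory: Long => Long): (Boolean, Long) = {
    val amountToRequest = 2 * currentMemory - threshold
    val granted = acquireMemory(amountToRequest)
    val newThreshold = threshold + granted
    (currentMemory >= newThreshold, newThreshold)
  }

  def main(args: Array[String]): Unit = {
    val fiveMB = 5L << 20
    val seventeenMB = 17L << 20
    // Full request granted: new threshold 34 MB, 17 MB used is below it, so no spill.
    println(shouldSpill(seventeenMB, fiveMB, request => request))
    // Only 10 MB granted: new threshold 15 MB, 17 MB used exceeds it, so spill.
    println(shouldSpill(seventeenMB, fiveMB, _ => 10L << 20))
  }
}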

The spill method is as follows:

// ExternalSorter.scala

override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
  val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
  val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
  spills += spillFile
}

  • It first calls destructiveSortedWritablePartitionedIterator to sort the in-memory data:

// WritablePartitionedPairCollection.scala

/**
* Iterate through the data and write out the elements instead of returning them. Records are
* returned in order of their partition ID and then the given comparator.
* This may destroy the underlying collection.
*/
def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
  : WritablePartitionedIterator = {
  val it = partitionedDestructiveSortedIterator(keyComparator)
  new WritablePartitionedIterator {
    private[this] var cur = if (it.hasNext) it.next() else null
    def writeNext(writer: DiskBlockObjectWriter): Unit = {
      writer.write(cur._1._2, cur._2)
      cur = if (it.hasNext) it.next() else null
    }
    def hasNext(): Boolean = cur != null
    def nextPartition(): Int = cur._1._1
  }
}

  • which in turn calls partitionedDestructiveSortedIterator to build an iterator over the data in the required order.

First, look at partitionedDestructiveSortedIterator in PartitionedPairBuffer:

// PartitionedPairBuffer.scala
/** Iterate through the data in a given order. For this class this is not really destructive. */
override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
  : Iterator[((Int, K), V)] = {
  val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
  new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator)
  iterator
}

It first obtains a comparator and then creates a Sorter to sort the data.

The comparator rule is: if a key comparator is supplied, sort first by partition ID and then by key using that comparator; if none is supplied, sort by partition ID only.
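
A hedged sketch of those two comparators (paraphrased from WritablePartitionedPairCollection, not copied verbatim):

import java.util.Comparator

// Used when no key comparator is supplied: order by partition ID only.
def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
  override def compare(a: (Int, K), b: (Int, K)): Int = a._1 - b._1
}

// Used when a key comparator is supplied: order by partition ID first, then by key.
def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] =
  new Comparator[(Int, K)] {
    override def compare(a: (Int, K), b: (Int, K)): Int = {
      val byPartition = a._1 - b._1
      if (byPartition != 0) byPartition else keyComparator.compare(a._2, b._2)
    }
  }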

Next is partitionedDestructiveSortedIterator in PartitionedAppendOnlyMap:

def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
  : Iterator[((Int, K), V)] = {
  val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
  destructiveSortedIterator(comparator)
}

It builds the comparator the same way as above, and then calls destructiveSortedIterator in AppendOnlyMap (the parent class of PartitionedAppendOnlyMap).

// AppendOnlyMap.scala

/**
* Return an iterator of the map in sorted order. This provides a way to sort the map without
* using additional memory, at the expense of destroying the validity of the map.
*/
def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
  destroyed = true
  // Pack KV pairs into the front of the underlying array
  var keyIndex, newIndex = 0
  while (keyIndex < capacity) {
    if (data(2 * keyIndex) != null) {
      data(2 * newIndex) = data(2 * keyIndex)
      data(2 * newIndex + 1) = data(2 * keyIndex + 1)
      newIndex += 1
    }
    keyIndex += 1
  }
  assert(curSize == newIndex + (if (haveNullValue) 1 else 0))


  new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, newIndex, keyComparator)

  new Iterator[(K, V)] {
    var i = 0
    var nullValueReady = haveNullValue
    def hasNext: Boolean = (i < newIndex || nullValueReady)
    def next(): (K, V) = {
      if (nullValueReady) {
        nullValueReady = false
        (null.asInstanceOf[K], nullValue)
      } else {
        val item = (data(2 * i).asInstanceOf[K], data(2 * i + 1).asInstanceOf[V])
        i += 1
        item
      }
    }
  }
}

It first compacts the KV pairs to the front of data, the underlying storage array (skipping the null slots in between);

then it creates a Sorter to sort the data.

From the code above we can see that both PartitionedPairBuffer and PartitionedAppendOnlyMap store their data in a flat Array, with each key immediately followed by its value: keys sit at even indices and each value at its key's index + 1.
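
A small sketch of that layout (the values are made up; this is not Spark code, just an illustration of the indexing scheme):

// One logical slot i occupies two cells: key at data(2 * i), value at data(2 * i + 1).
val capacity = 4
val data = new Array[AnyRef](2 * capacity)

def put(slot: Int, key: AnyRef, value: AnyRef): Unit = {
  data(2 * slot) = key
  data(2 * slot + 1) = value
}

put(0, "b", Int.box(2))
put(1, "a", Int.box(1))
// KVArraySortDataFormat swaps both cells of a slot together during the sort, so a key
// is never separated from its value.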

Back in ExternalSorter.spill: once the in-memory records are sorted, the data is spilled to disk via val spillFile = spillMemoryIteratorToDisk(inMemoryIterator).

/**
* Spill contents of in-memory iterator to a temporary file on disk.
*/
private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
    : SpilledFile = {
  // Because these files may be read during shuffle, their compression must be controlled by
  // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
  // createTempShuffleBlock here; see SPARK-3426 for more context.
  val (blockId, file) = diskBlockManager.createTempShuffleBlock()


  // These variables are reset after each flush
  var objectsWritten: Long = 0
  val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
  val writer: DiskBlockObjectWriter =
    blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)


  // List of batch sizes (bytes) in the order they are written to disk
  val batchSizes = new ArrayBuffer[Long]


  // How many elements we have in each partition
  val elementsPerPartition = new Array[Long](numPartitions)


  // Flush the disk writer's contents to disk, and update relevant variables.
  // The writer is committed at the end of this process.
  def flush(): Unit = {
    val segment = writer.commitAndGet()
    batchSizes += segment.length
    _diskBytesSpilled += segment.length
    objectsWritten = 0
  }


  var success = false
  try {
    while (inMemoryIterator.hasNext) {
      val partitionId = inMemoryIterator.nextPartition()
      require(partitionId >= 0 && partitionId < numPartitions,
        s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
      inMemoryIterator.writeNext(writer)
      elementsPerPartition(partitionId) += 1
      objectsWritten += 1


      if (objectsWritten == serializerBatchSize) {
        flush()
      }
    }
    if (objectsWritten > 0) {
      flush()
    } else {
      writer.revertPartialWritesAndClose()
    }
    success = true
  } finally {
    if (success) {
      writer.close()
    } else {
      // This code path only happens if an exception was thrown above before we set success;
      // close our stuff and let the exception be thrown further
      writer.revertPartialWritesAndClose()
      if (file.exists()) {
        if (!file.delete()) {
          logWarning(s"Error deleting ${file}")
        }
      }
    }
  }
  SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
}

It iterates over the sorted iterator and writes the records to a spill file through a DiskBlockObjectWriter; each spill produces one file. During the spill the writer flushes in batches: once spark.shuffle.spill.batchSize records (10000 by default) have been written, flush is called to push the batch to disk. Finally the spill file and its metadata are wrapped into a SpilledFile.
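
The batch size is governed by spark.shuffle.spill.batchSize; a hedged example of lowering it (the value 1000 is illustrative, e.g. for jobs with very large individual records):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.spill.batchSize", "1000") // default is 10000 objects per flush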

Now back to SortShuffleWriter.write:

val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)

The tmp variable in the code is the final data file being created. sorter (the ExternalSorter) then merges everything into it via writePartitionedFile.

// ExternalSorter.scala

/**
* Write all the data added into this ExternalSorter into a file in the disk store. This is
* called by the SortShuffleWriter.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
def writePartitionedFile(
    blockId: BlockId,
    outputFile: File): Array[Long] = {


  // Track location of each range in the output file
  val lengths = new Array[Long](numPartitions)
  val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
    context.taskMetrics().shuffleWriteMetrics)


  if (spills.isEmpty) {
    // Case where we only have in-memory data
    val collection = if (aggregator.isDefined) map else buffer
    val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
    while (it.hasNext) {
      val partitionId = it.nextPartition()
      while (it.hasNext && it.nextPartition() == partitionId) {
        it.writeNext(writer)
      }
      val segment = writer.commitAndGet()
      lengths(partitionId) = segment.length
    }
  } else {
    // We must perform merge-sort; get an iterator by partition and write everything directly.
    for ((id, elements) <- this.partitionedIterator) {
      if (elements.hasNext) {
        for (elem <- elements) {
          writer.write(elem._1, elem._2)
        }
        val segment = writer.commitAndGet()
        lengths(id) = segment.length
      }
    }
  }


  writer.close()
  context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
  context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
  context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)


  lengths
}

  • If there is only in-memory data (no spills), the in-memory data is sorted and written straight into the final data file, recording the byte length written for each target partition (the reduce-side partitions) for use when generating the index file later.

  • If spill files exist, partitionedIterator is called first to merge the spilled files with the remaining in-memory data into a single iterator.

// ExternalSorter.scala

/**
* Return an iterator over all the data written to this object, grouped by partition and
* aggregated by the requested aggregator. For each partition we then have an iterator over its
* contents, and these are expected to be accessed in order (you can't "skip ahead" to one
* partition without reading the previous one). Guaranteed to return a key-value pair for each
* partition, in order of partition ID.
*
* For now, we just merge all the spilled files in once pass, but this can be modified to
* support hierarchical merging.
* Exposed for testing.
*/
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
  val usingMap = aggregator.isDefined
  val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
  if (spills.isEmpty) {
    // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
    // we don't even need to sort by anything other than partition ID
    if (!ordering.isDefined) {
      // The user hasn't requested sorted keys, so only sort by partition ID, not key
      groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
    } else {
      // We do need to sort by both partition ID and key
      groupByPartition(destructiveIterator(
        collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
    }
  } else {
    // Merge spilled and in-memory data
    merge(spills, destructiveIterator(
      collection.partitionedDestructiveSortedIterator(comparator)))
  }
}

The merge method:

// ExternalSorter.scala
/**
* Merge a sequence of sorted files, giving an iterator over partitions and then over elements
* inside each partition. This can be used to either write out a new file or return data to
* the user.
*
* Returns an iterator over all the data written to this object, grouped by partition. For each
* partition we then have an iterator over its contents, and these are expected to be accessed
* in order (you can't "skip ahead" to one partition without reading the previous one).
* Guaranteed to return a key-value pair for each partition, in order of partition ID.
*/
private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
    : Iterator[(Int, Iterator[Product2[K, C]])] = {
  // Build a SpillReader for each spilled file recorded in `spills`, giving readers: Seq[SpillReader]
  val readers = spills.map(new SpillReader(_))
  // Buffered iterator over the data still in memory
  val inMemBuffered = inMemory.buffered
   // Iterate over every (target) partition
  (0 until numPartitions).iterator.map { p =>
    // Build an IteratorForPartition. The in-memory data is sorted by partition ID first, so its hasNext is:
    // override def hasNext: Boolean = data.hasNext && data.head._1._1 == partitionId
    val inMemIterator = new IteratorForPartition(p, inMemBuffered)
      // Call readNextPartition() on each SpillReader to get an iterator over that spill file's data for this partition.
    val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
   // If an aggregator is defined, merge with mergeWithAggregation
    if (aggregator.isDefined) {
      // Perform partial aggregation across partitions
      (p, mergeWithAggregation(
        iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
    } else if (ordering.isDefined) {
      // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
      // sort the elements without trying to merge them
      // call mergeSort
      (p, mergeSort(iterators, ordering.get))
    } else {
      // neither aggregation nor ordering: just concatenate the iterators
      (p, iterators.iterator.flatten)
    }
  }
}

What merge returns is an iterator over each partition's merged data; writePartitionedFile then writes the records one by one into the final data file.

Finally, shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) is executed to generate the index file.
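
A hedged sketch of what the index file encodes: writeIndexFileAndCommit turns partitionLengths into cumulative byte offsets, so reducer i later reads the byte range [offsets(i), offsets(i + 1)) from the single .data file (the lengths below are made up):

val partitionLengths = Array(120L, 0L, 345L, 80L)
val offsets = partitionLengths.scanLeft(0L)(_ + _)
// offsets == Array(0, 120, 120, 465, 545); an empty partition just repeats the previous offset.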

UnsafeShuffleWriter will be covered separately later.