
提交任务集(TaskSet)时,Spark 中的任务分为两种:ShuffleMapTask 和 ResultTask。下面分析 ShuffleMapTask 主要做了哪些工作。

/**
 * A task that computes one partition of a shuffle map stage: it runs the RDD's iterator for its
 * partition, writes the records through the ShuffleManager's writer and returns the MapStatus
 * produced by that writer.
 *
 * @param stageId    id of the stage this task belongs to
 * @param taskBinary broadcast of the serialized (RDD, ShuffleDependency) pair
 * @param partition  the RDD partition this task computes
 * @param locs       preferred locations for this task; may be null (meaning no preference)
 */
private[spark] class ShuffleMapTask(
    stageId: Int,
    taskBinary: Broadcast[Array[Byte]],
    partition: Partition,
    @transient private var locs: Seq[TaskLocation])
  extends Task[MapStatus](stageId, partition.index) with Logging {

  /** A constructor used only in test suites. This does not require passing in an RDD. */
  def this(partitionId: Int) = {
    this(0, null, new Partition { override def index = 0 }, null)
  }

  // De-duplicated preferred locations; a null `locs` yields no preference.
  @transient private val preferredLocs: Seq[TaskLocation] = {
    if (locs == null) Nil else locs.toSet.toSeq
  }

  /**
   * Deserializes the (RDD, ShuffleDependency) pair from the broadcast, computes this task's
   * partition and writes its records with the writer obtained from the ShuffleManager.
   *
   * @param context context of the running task
   * @return the MapStatus returned by `writer.stop(success = true)`
   * @throws Exception any exception raised while computing or writing is rethrown after a
   *                   best-effort attempt to stop the writer with `success = false`
   */
  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      // This is where the RDD partition is actually computed; its records are handed to the
      // shuffle writer, which persists the intermediate (map-side) output.
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          // A failure while stopping the writer must not mask the original exception.
          case stopError: Exception =>
            log.debug("Could not stop writer", stopError)
        }
        throw e
    }
  }

  override def preferredLocations: Seq[TaskLocation] = preferredLocs

  override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
}

ShuffleManager 是一个 trait,有两个实现类:SortShuffleManager 和 HashShuffleManager。创建 Executor 的 SparkEnv 时默认使用的是 SortShuffleManager,下面就按默认实现进行分析。先看它的 getWriter() 方法,该方法为当前 map 任务创建一个 SortShuffleWriter:

  /**
   * Returns a SortShuffleWriter for one map task of the given shuffle. Before creating it, the
   * shuffle's number of map tasks is recorded in `shuffleMapNumber` via `putIfAbsent`.
   *
   * @param handle  shuffle handle; cast (unchecked) to BaseShuffleHandle
   * @param mapId   index of the map task (ShuffleMapTask passes its partitionId here)
   * @param context context of the running task
   * @return a new SortShuffleWriter bound to this map task
   */
  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
      : ShuffleWriter[K, V] = {
    val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]]
    shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps)
    new SortShuffleWriter(
      shuffleBlockManager, baseShuffleHandle, mapId, context)
  }

继续跟踪 SortShuffleWriter 的write() 方法:

  /**
   * Feeds all records of this map task into an ExternalSorter, writes the sorted/partitioned
   * result into a single data file plus an index file, and records the resulting MapStatus.
   *
   * @param records the key-value pairs produced by this map task's partition
   * @throws IllegalStateException if map-side combine is requested but no aggregator is defined
   */
  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
    if (dep.mapSideCombine) {
      // Map-side combine: values are aggregated on the map side with the dependency's aggregator.
      if (!dep.aggregator.isDefined) {
        throw new IllegalStateException("Aggregator is empty for map-side combine")
      }
      sorter = new ExternalSorter[K, V, C](
        dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
      sorter.insertAll(records)
    } else {
      // No map-side combine requested. The dependency may still define an aggregator; it is
      // simply not applied on the map side.
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      sorter = new ExternalSorter[K, V, V](
        None, Some(dep.partitioner), None, dep.serializer)
      sorter.insertAll(records)
    }

    // One data file per map task; writeIndexFile receives the per-partition lengths so that each
    // partition's bytes can be located inside that single file.
    val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
    val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)
    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
    shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)

    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  }

继续跟踪 ExternalSorter 的 insertAll()方法:

  /**
   * Inserts all records into this sorter using one of three strategies:
   *  - an aggregator is defined: combine values per (partitionId, key) in the in-memory `map`;
   *  - `bypassMergeSort` is set: write records straight to per-partition files, unsorted;
   *  - otherwise: append records to the in-memory `buffer`.
   * The two in-memory strategies call `maybeSpillCollection` after every record, which may
   * spill the collection to disk.
   *
   * @param records the key-value pairs to insert
   */
  def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
    // TODO: stop combining if we find that the reduction factor isn't high
    val shouldCombine = aggregator.isDefined

    if (shouldCombine) {
      // Combine values in-memory first using our AppendOnlyMap
      val mergeValue = aggregator.get.mergeValue
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      // `update` reads the current record through the captured `kv` var, so `kv` must be assigned
      // before each changeValue call.
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      while (records.hasNext) {
        kv = records.next()
        map.changeValue((getPartition(kv._1), kv._1), update)
        // Presumably spills the in-memory map to disk when it has grown too large.
        maybeSpillCollection(usingMap = true)
      }
    } else if (bypassMergeSort) {
      // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
      if (records.hasNext) {
        // Tag each record as ((partitionId, key), value) and write it to its partition's file.
        spillToPartitionFiles(records.map { kv =>
          ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
        })
      }
    } else {
      // Stick values into our buffer
      while (records.hasNext) {
        val kv = records.next()
        buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
      }
    }
  }

这里只看 bypassMergeSort 分支中调用的 spillToPartitionFiles() 方法(它不做排序,直接把每条记录写入其所属分区对应的文件),源码如下:

  /**
   * Writes each ((partitionId, key), value) element to the file writer of its partition, without
   * sorting. On first use this creates one temporary shuffle file and disk writer per output
   * partition, sharing a fresh ShuffleWriteMetrics (`curWriteMetrics`).
   *
   * @param iterator elements tagged with their partition id and key
   */
  private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = {
    // Create our file writers if we haven't done so yet (one writer per reduce partition).
    if (partitionWriters == null) {
      curWriteMetrics = new ShuffleWriteMetrics()
      partitionWriters = Array.fill(numPartitions) {
        // Because these files may be read during shuffle, their compression must be controlled by
        // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
        // createTempShuffleBlock here; see SPARK-3426 for more context.
        val (blockId, file) = diskBlockManager.createTempShuffleBlock()
        blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open()
      }
    }

    // No need to sort stuff, just write each element out
    while (iterator.hasNext) {
      val elem = iterator.next()
      val partitionId = elem._1._1
      val key = elem._1._2
      val value = elem._2
      // All records with the same partition id go through the same writer.
      partitionWriters(partitionId).write((key, value))
    }
  }

接下来回到 SortShuffleWriter 的 write() 方法,看看 insertAll() 之后还做了哪些操作,代码如下:

    // The actual map-side output (data) file of this map task.
    val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
    // One map task corresponds to one blockId, i.e. one file.
    val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)
    // Key step: write all partitions into outputFile; returns the byte length of each partition.
    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
    // Index file built from the per-partition lengths, so each partition's slice (its offset) in
    // the data file can be located.
    shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)

    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)

继续查看关键的一步 writePartitionedFile() :

  /**
   * Writes all data added to this sorter into a single file, one contiguous range per partition
   * in partition-id order, and returns the length of each range.
   *
   * Two paths:
   *  - bypass path (`bypassMergeSort` and per-partition files exist): the remaining in-memory
   *    data is spilled to the partition files, which are then concatenated into `outputFile`;
   *  - otherwise: `partitionedIterator` is walked and each non-empty partition is written to
   *    `outputFile` with a disk writer.
   * Spill and write statistics are then added to the task's metrics.
   *
   * @param blockId    block id passed to the disk writer in the non-bypass path
   * @param context    task context whose metrics are updated
   * @param outputFile the file to write to (opened in append mode in the bypass path)
   * @return per-partition lengths indexed by partition id; 0 for partitions with no data
   */
  def writePartitionedFile(
      blockId: BlockId,
      context: TaskContext,
      outputFile: File): Array[Long] = {

    // Track location of each range in the output file
    val lengths = new Array[Long](numPartitions)

    if (bypassMergeSort && partitionWriters != null) {
      // No sorting, and per-partition files were already created. (The original note says the
      // bypass is used by default with fewer than ~200 reduce partitions; the exact condition is
      // defined outside this excerpt -- confirm.)
      // We decided to write separate files for each partition, so just concatenate them. To keep
      // this simple we spill out the current in-memory collection so that everything is in files.
      spillToPartitionFiles(if (aggregator.isDefined) map else buffer)
      partitionWriters.foreach(_.commitAndClose())  // flush and close every partition file
      var out: FileOutputStream = null
      var in: FileInputStream = null
      try {
        out = new FileOutputStream(outputFile, true)
        for (i <- 0 until numPartitions) {
          in = new FileInputStream(partitionWriters(i).fileSegment().file)
          // Append partition i's file to outputFile; `size` is the number of bytes copied.
          val size = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled)
          // copyStream is called with `false` (presumably closeStreams = false), so close `in`
          // here before dropping the reference; otherwise one stream leaks per partition.
          in.close()
          in = null
          lengths(i) = size  // length (not offset) of partition i's range
        }
      } finally {
        if (out != null) {
          out.close()
        }
        if (in != null) {
          in.close()
        }
      }
    } else {
      // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by
      // partition and just write everything directly.
      for ((id, elements) <- this.partitionedIterator) {
        if (elements.hasNext) {
          val writer = blockManager.getDiskWriter(
            blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get)
          for (elem <- elements) {
            writer.write(elem)
          }
          // Commit before reading the segment so its length covers everything just written.
          writer.commitAndClose()
          val segment = writer.fileSegment()
          lengths(id) = segment.length
        }
      }
    }

    context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
    context.taskMetrics.diskBytesSpilled += diskBytesSpilled
    // In the bypass path the per-partition writers report into curWriteMetrics; fold those
    // numbers into the task's shuffle write metrics.
    context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m =>
      if (curWriteMetrics != null) {
        m.shuffleBytesWritten += curWriteMetrics.shuffleBytesWritten
        m.shuffleWriteTime += curWriteMetrics.shuffleWriteTime
      }
    }

    lengths
  }


writePartitionedFile() 的 else 分支通过另一个关键方法 partitionedIterator() 获取按分区分组的数据:

  /**
   * Returns an iterator of (partitionId, records) pairs covering all data added to this sorter.
   * It uses one of three paths:
   *  - only in-memory data: sort the in-memory collection by partition id, and also by key if an
   *    ordering is defined;
   *  - bypass path with per-partition files: append each partition file's contents to the
   *    in-memory records of that partition;
   *  - spilled data exists: merge the spills with the in-memory data.
   * The in-memory collection is read through `destructiveSortedIterator`. As the name suggests,
   * it should presumably not be reused afterwards.
   */
  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
    val usingMap = aggregator.isDefined
    val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
    if (spills.isEmpty && partitionWriters == null) {
      // Nothing spilled and no partition files were written: all data is in memory.
      // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
      // we don't even need to sort by anything other than partition ID
      if (!ordering.isDefined) {
        // The user hasn't requested sorted keys, so only sort by partition ID, not key
        groupByPartition(collection.destructiveSortedIterator(partitionComparator))
      } else {
        // We do need to sort by both partition ID and key (partition ID first, then key)
        groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
      }
    } else if (bypassMergeSort) {
      // No sorting, but some data is already in partition files and must be combined with memory.
      // Read data from each partition file and merge it together with the data in memory;
      // note that there's no ordering or aggregator in this case -- we just partition objects
      val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator))
      collIter.map { case (partitionId, values) =>
        (partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
      }
    } else {
      // Merge spilled and in-memory data
      merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
    }
  }

