Spark MapOutputTracker Source Code Walkthrough


MapOutputTracker has two concrete implementations: MapOutputTrackerMaster, which runs on the driver, and MapOutputTrackerWorker, which runs on each executor.

MapOutputTrackerMaster

When a MapOutputTrackerMaster object is constructed, three arguments are passed: conf, broadcastManager, and isLocal.
isLocal is false in yarn-cluster mode.
Let's look at this class in detail. Fields:

//Threshold above which the serialized map output statuses are broadcast rather than sent directly in the RPC reply
private val minSizeForBroadcast =
    conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt
//Whether to prefer data locality for reduce tasks (see the locality sketch after this block)
private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true)
private val SHUFFLE_PREF_MAP_THRESHOLD = 1000
private val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000
private val REDUCER_PREF_LOCS_FRACTION = 0.2
//Holds the state of every shuffle; the Int key is the shuffle ID
val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala
//Maximum RPC message size, 128 MB by default (configurable via spark.rpc.message.maxSize)
private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
//Blocking queue holding incoming map output requests
private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
//Thread pool that drains the request queue above; by default 8 threads handle messages
private val threadpool: ThreadPoolExecutor = {
    val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8)
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher")
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }
//Fail fast if the broadcast threshold exceeds the maximum RPC message size
if (minSizeForBroadcast > maxRpcMessageSize) {
    val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " +
      s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " +
      "message that is too large."
    logError(msg)
    throw new IllegalArgumentException(msg)
  }
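
The three locality constants above come into play when the scheduler asks for preferred locations of reduce tasks. Below is a condensed sketch of that logic, abridged from the same class (getLocationsWithLargestOutputs, not shown here, returns the BlockManagerIds that already hold at least the given fraction of a partition's map output bytes); treat it as an approximation rather than the verbatim source:

// Reduce-task locality is only computed for modest shuffles (fewer than 1000
// maps and 1000 reduces). A host is "preferred" for a reduce partition when it
// already holds at least REDUCER_PREF_LOCS_FRACTION (20%) of that partition's
// map output bytes.
def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int)
    : Seq[String] = {
  if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
      dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
    val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
      dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
    if (blockManagerIds.nonEmpty) blockManagerIds.get.map(_.host) else Nil
  } else {
    Nil
  }
}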

//Enqueue a GetMapOutputMessage onto mapOutputRequests
def post(message: GetMapOutputMessage): Unit = {
    mapOutputRequests.offer(message)
  }
//A sentinel GetMapOutputMessage used to signal the dispatcher threads to exit
private val PoisonPill = new GetMapOutputMessage(-99, null)
//Number of shuffles whose serialized statuses are cached as a broadcast
private[spark] def getNumCachedSerializedBroadcast: Int = {
//shuffleStatuses holds every shuffle; each entry holds all MapStatuses (map outputs) of that shuffle, and each shuffle may have serialized its statuses in broadcast form
    shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
  }
//Register a shuffle in shuffleStatuses with its ID and its number of map outputs
def registerShuffle(shuffleId: Int, numMaps: Int) {
    if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
    }
  }
//Register one map task's MapStatus for a shuffle in shuffleStatuses
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
    shuffleStatuses(shuffleId).addMapOutput(mapId, status)
  }
//Unregister one map output, identified by shuffleId, mapId and bmAddress
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
    shuffleStatuses.get(shuffleId) match {
      case Some(shuffleStatus) =>
        shuffleStatus.removeMapOutput(mapId, bmAddress)
        incrementEpoch()
      case None =>
        throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
    }
  }

//Unregister a shuffle by shuffleId and clean up its data
def unregisterShuffle(shuffleId: Int) {
    shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
      shuffleStatus.invalidateSerializedMapOutputStatusCache()
    }
  }

//Remove all outputs located on the given host and bump the epoch
def removeOutputsOnHost(host: String): Unit = {
    shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) }
    incrementEpoch()
  }
//Remove all outputs located on the given executor and bump the epoch
def removeOutputsOnExecutor(execId: String): Unit = {
    shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) }
    incrementEpoch()
  }
//Whether shuffleStatuses contains this shuffleId
def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId)

//Get the number of available outputs for a shuffleId
def getNumAvailableOutputs(shuffleId: Int): Int = {
    shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0)
  }
//Look up the missing partitions for a shuffleId
def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = {
    shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
  }
//Split a Range into consecutive sub-ranges of the given size (helper for equallyDivide below)
def rangeGrouped(range: Range, size: Int): Seq[Range] = {
    val start = range.start
    val step = range.step
    val end = range.end
    for (i <- start.until(end, size * step)) yield {
      i.until(i + size * step, step)
    }
  }

def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
    val elementsPerBucket = numElements / numBuckets
    val remaining = numElements % numBuckets
    val splitPoint = (elementsPerBucket + 1) * remaining
    if (elementsPerBucket == 0) {
      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
    } else {
      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
        rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
    }
  }
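
A quick sanity check of the two helpers above (a standalone sketch; the EquallyDivideDemo wrapper is hypothetical, not part of Spark): splitting 10 elements into 3 buckets yields sizes 4, 3, 3, with the remainder spread over the leading buckets.

object EquallyDivideDemo {
  def rangeGrouped(range: Range, size: Int): Seq[Range] = {
    val (start, step, end) = (range.start, range.step, range.end)
    for (i <- start.until(end, size * step)) yield i.until(i + size * step, step)
  }

  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
    val elementsPerBucket = numElements / numBuckets
    val remaining = numElements % numBuckets
    val splitPoint = (elementsPerBucket + 1) * remaining
    if (elementsPerBucket == 0) {
      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
    } else {
      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
        rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
    }
  }

  def main(args: Array[String]): Unit = {
    // buckets: [0,1,2,3], [4,5,6], [7,8,9]
    println(equallyDivide(10, 3).map(_.toList))
  }
}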

Inner class:

private class MessageLoop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          try {
            val data = mapOutputRequests.take()
            if (data == PoisonPill) {
              // Stop every message-handling thread: put the PoisonPill back so
              // that the other MessageLoops can see it and exit as well.
              mapOutputRequests.offer(PoisonPill)
              return
            }
            val context = data.context
            val shuffleId = data.shuffleId
            val hostPort = context.senderAddress.hostPort
            logDebug("Handling request to send map output locations for shuffle " + shuffleId +
              " to " + hostPort)
            val shuffleStatus = shuffleStatuses.get(shuffleId).head
            context.reply(
              shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast))
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case ie: InterruptedException => // exit
      }
    }
  }
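
For context, here is a condensed sketch of the driver-side RPC endpoint that feeds the queue this loop drains (an approximation of MapOutputTrackerMasterEndpoint, not the verbatim source). Note that the endpoint does not reply inline: the reply is sent later by a MessageLoop thread via context.reply. Stopping the master offers the PoisonPill onto the same queue, which is what makes every MessageLoop thread return.

private[spark] class MapOutputTrackerMasterEndpoint(
    override val rpcEnv: RpcEnv,
    tracker: MapOutputTrackerMaster,
    conf: SparkConf)
  extends RpcEndpoint with Logging {

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case GetMapOutputStatuses(shuffleId: Int) =>
      // Do not reply here: hand the request to the master's queue; one of the
      // MessageLoop threads will eventually call context.reply(...)
      tracker.post(new GetMapOutputMessage(shuffleId, context))

    case StopMapOutputTracker =>
      context.reply(true)
      stop()
  }
}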

MapOutputTrackerWorker

/**
 * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster.
 * Note that this is not used in local-mode; instead, local-mode Executors access the
 * MapOutputTrackerMaster directly (which is possible because the master and worker share a common
 * superclass).
 */
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
//Holds the fetched Array[MapStatus] for each shuffle
  val mapStatuses: Map[Int, Array[MapStatus]] =
    new ConcurrentHashMap[Int, Array[MapStatus]]().asScala

  /** Remembers which map output locations are currently being fetched on an executor. */
  private val fetching = new HashSet[Int]

  override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
    val statuses = getStatuses(shuffleId)
    try {
      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) // convert the statuses of the requested partitions into (location, (shuffleBlockId, size)) pairs; see the sketch after this class
    } catch {
      case e: MetadataFetchFailedException =>
        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
        mapStatuses.clear()
        throw e
    }
  }

  /**
   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
   * on this array when reading it, because on the driver, we may be changing it in place.
   *
   * (It would be nice to remove this restriction in the future.)
   */
  // First try the local cache for this shuffleId's mapStatuses; on a miss, fetch them from the driver
  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
    val statuses = mapStatuses.get(shuffleId).orNull // look up this shuffleId's MapStatuses in the local cache
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      val startTime = System.currentTimeMillis
      var fetchedStatuses: Array[MapStatus] = null
      fetching.synchronized {
        // Someone else is fetching it; wait for them to be done
        while (fetching.contains(shuffleId)) {
          try {
            fetching.wait()
          } catch {
            case e: InterruptedException =>
          }
        }

        // Either while we waited the fetch happened successfully, or
        // someone fetched it in between the get and the fetching.synchronized.
        fetchedStatuses = mapStatuses.get(shuffleId).orNull
        if (fetchedStatuses == null) {
          // We have to do the fetch, get others to wait for us.
          fetching += shuffleId
        }
      }

      if (fetchedStatuses == null) {
        // We won the race to fetch the statuses; do so
        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
        // This try-finally prevents hangs due to timeouts:
        try {
          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) // ask the driver;
          // the reply is produced by MapOutputTrackerMasterEndpoint.receiveAndReply,
          // which eventually returns the serialized statuses for this shuffleId
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) // deserialize into Array[MapStatus]
          logInfo("Got the output locations")
          mapStatuses.put(shuffleId, fetchedStatuses) // cache them in this node's mapStatuses
        } finally {
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      }
      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
        s"${System.currentTimeMillis - startTime} ms")

      if (fetchedStatuses != null) {
        fetchedStatuses
      } else {
        logError("Missing all output locations for shuffle " + shuffleId)
        throw new MetadataFetchFailedException(
          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
      }
    } else {
      statuses
    }
  }


  /** Unregister shuffle data. */
  //Remove this shuffleId's mapStatuses entry
  def unregisterShuffle(shuffleId: Int): Unit = {
    mapStatuses.remove(shuffleId)
  }

  /**
   * Called from executors to update the epoch number, potentially clearing old outputs
   * because of a fetch failure. Each executor task calls this with the latest epoch
   * number on the driver at the time it was created.
   */
  //Update this node's epoch; org.apache.spark.executor.Executor.run calls this to refresh the executor's epoch
  def updateEpoch(newEpoch: Long): Unit = {
    epochLock.synchronized {
      if (newEpoch > epoch) {
        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
        epoch = newEpoch
        mapStatuses.clear()
      }
    }
  }
}
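
getMapSizesByExecutorId above delegates to MapOutputTracker.convertMapStatuses. Here is a simplified sketch of that conversion (an approximation, not the verbatim source): for every map output, look up the size of each requested reduce partition and group the resulting ShuffleBlockIds by the BlockManagerId that hosts them.

def convertMapStatuses(
    shuffleId: Int,
    startPartition: Int,
    endPartition: Int,
    statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  val splitsByAddress =
    mutable.HashMap.empty[BlockManagerId, mutable.ArrayBuffer[(BlockId, Long)]]
  for ((status, mapId) <- statuses.iterator.zipWithIndex) {
    if (status == null) {
      // A missing entry means the map output was lost; the caller must recompute it
      throw new MetadataFetchFailedException(
        shuffleId, startPartition, "Missing an output location for shuffle " + shuffleId)
    }
    for (part <- startPartition until endPartition) {
      splitsByAddress.getOrElseUpdate(status.location, mutable.ArrayBuffer()) +=
        ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
    }
  }
  splitsByAddress.toSeq
}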

abstract MapOutputTracker

This is the shared abstract base class:

//Takes a SparkConf
abstract class MapOutputTracker(conf: SparkConf) extends Logging

Fields:

//Reference to the tracker endpoint, which lives on the driver
var trackerEndpoint: RpcEndpointRef = _
protected var epoch: Long = 0 // an epoch counter maintained on the driver
protected val epochLock = new AnyRef

Methods:

//Send a message to trackerEndpoint and wait for its result within the default RPC timeout
protected def askTracker[T: ClassTag](message: Any): T = {
    try {
      trackerEndpoint.askSync[T](message)
    } catch {
      case e: Exception =>
        logError("Error communicating with MapOutputTracker", e)
        throw new SparkException("Error communicating with MapOutputTracker", e)
    }
  }
//Send a one-way style message to trackerEndpoint; the endpoint is expected to reply with true
protected def sendTracker(message: Any) {
    val response = askTracker[Boolean](message)
    if (response != true) {
      throw new SparkException(
        "Error reply received from MapOutputTracker. Expecting true, got " + response.toString)
    }
  }
//Get the map output locations and sizes for a single reduce partition, grouped by BlockManagerId
def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
  }
//Get the map output locations and sizes for a range of reduce partitions [startPartition, endPartition), grouped by BlockManagerId
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])]
//Unregister a shuffle
def unregisterShuffle(shuffleId: Int): Unit
//Stop the tracker (a no-op by default)
def stop() {}

class ShuffleStatus

This class records the state of a single shuffle.
Fields:

//Holds an array with one MapStatus per map partition;
//numPartitions here is the total number of map outputs
val mapStatuses = new Array[MapStatus](numPartitions)
//Cached serialized MapStatuses
private[this] var cachedSerializedMapStatus: Array[Byte] = _
//Cached serialized broadcast of the MapStatuses
private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
//Count of the available outputs
private[this] var _numAvailableOutputs: Int = 0

Methods:

//Add a status to the internal mapStatuses; an existing entry is overwritten
def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
    if (mapStatuses(mapId) == null) {
      _numAvailableOutputs += 1
      invalidateSerializedMapOutputStatusCache()
    }
    mapStatuses(mapId) = status
  }

//Clears the cached serialized map output statuses.
def invalidateSerializedMapOutputStatusCache(): Unit = synchronized {
    if (cachedSerializedBroadcast != null) {
      // Prevent errors during broadcast cleanup from crashing the DAGScheduler (see SPARK-21444)
      Utils.tryLogNonFatalError {
        // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup
        // RPCs to dead executors.
        cachedSerializedBroadcast.destroy(blocking = false)
      }
      cachedSerializedBroadcast = null
    }
    cachedSerializedMapStatus = null
  }
//Remove the status matching mapId and bmAddress
def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
    if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
      _numAvailableOutputs -= 1 // one fewer available output
      mapStatuses(mapId) = null
      invalidateSerializedMapOutputStatusCache() // clears the cached serialized map output statuses
    }
  }

//Remove outputs whose location matches a filter function
def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
    for (mapId <- 0 until mapStatuses.length) {
      if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) {
        _numAvailableOutputs -= 1
        mapStatuses(mapId) = null
        invalidateSerializedMapOutputStatusCache()
      }
    }
  }
//Remove outputs by the host of the location recorded in mapStatuses
def removeOutputsOnHost(host: String): Unit = {
    removeOutputsByFilter(x => x.host == host)
  }
//Remove outputs by the executor ID of the location recorded in mapStatuses
def removeOutputsOnExecutor(execId: String): Unit = synchronized {
    removeOutputsByFilter(x => x.executorId == execId)
  }
//Get the number of available outputs
def numAvailableOutputs: Int = synchronized {
    _numAvailableOutputs
  }
//Get the IDs of the partitions whose mapStatuses entry is null
def findMissingPartitions(): Seq[Int] = synchronized {
    val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
    assert(missing.size == numPartitions - _numAvailableOutputs,
      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
    missing
  }

//Serialize all MapStatuses, caching the bytes (and the broadcast, if one was created)
def serializedMapStatus(
      broadcastManager: BroadcastManager,
      isLocal: Boolean,
      minBroadcastSize: Int): Array[Byte] = synchronized {
    if (cachedSerializedMapStatus eq null) {
      val serResult = MapOutputTracker.serializeMapStatuses(
          mapStatuses, broadcastManager, isLocal, minBroadcastSize)
      cachedSerializedMapStatus = serResult._1
      cachedSerializedBroadcast = serResult._2
    }
    cachedSerializedMapStatus
  }

object MapOutputTracker

This is the companion object of MapOutputTracker; it holds some shared helper methods.

val ENDPOINT_NAME = "MapOutputTracker"
  private val DIRECT = 0 // tag: statuses serialized directly
  private val BROADCAST = 1 // tag: statuses delivered via a broadcast
//Serialize mapStatuses.
//First serialize with GZIP directly; if the result exceeds minBroadcastSize, switch to a broadcast: the bytes are wrapped in a broadcast, and it is the broadcast object itself that gets GZIP-serialized
def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager,
      isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = {
    val out = new ByteArrayOutputStream
    out.write(DIRECT)
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    Utils.tryWithSafeFinally {
      // Since statuses can be modified in parallel, sync on it
      statuses.synchronized {
        objOut.writeObject(statuses)
      }
    } {
      objOut.close()
    }
    val arr = out.toByteArray
    if (arr.length >= minBroadcastSize) {
      // Use broadcast instead.
      // Important arr(0) is the tag == DIRECT, ignore that while deserializing !
      val bcast = broadcastManager.newBroadcast(arr, isLocal)
      // toByteArray creates copy, so we can reuse out
      out.reset()
      out.write(BROADCAST)
      val oos = new ObjectOutputStream(new GZIPOutputStream(out))
      oos.writeObject(bcast)
      oos.close()
      val outArr = out.toByteArray
      logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length)
      (outArr, bcast)
    } else {
      (arr, null)
    }
  }

//Deserialize back into an Array[MapStatus]
def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
    assert (bytes.length > 0)

    def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = {
      val objIn = new ObjectInputStream(new GZIPInputStream(
        new ByteArrayInputStream(arr, off, len)))
      Utils.tryWithSafeFinally {
        objIn.readObject()
      } {
        objIn.close()
      }
    }

    bytes(0) match {
      case DIRECT =>
        deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]]
      case BROADCAST =>
        // deserialize the Broadcast, pull .value array out of it, and then deserialize that
        val bcast = deserializeObject(bytes, 1, bytes.length - 1).
          asInstanceOf[Broadcast[Array[Byte]]]
        logInfo("Broadcast mapstatuses size = " + bytes.length +
          ", actual size = " + bcast.value.length)
        // Important - ignore the DIRECT tag ! Start from offset 1
        deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]]
      case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0))
    }
  }
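
The wire format is easy to verify in isolation. Below is a minimal, self-contained demo of the DIRECT path (illustrative only; an Array[String] stands in for Array[MapStatus]): the one-byte tag sits outside the GZIP stream, so decoding must start at offset 1. The BROADCAST path nests the same framing, since the broadcast's value is the original DIRECT-tagged array, which is why deserializing bcast.value also skips its first byte.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object DirectFramingDemo {
  val DIRECT = 0

  def main(args: Array[String]): Unit = {
    // Serialize: a 1-byte tag, then a GZIP-compressed Java-serialized payload
    val out = new ByteArrayOutputStream
    out.write(DIRECT)
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    objOut.writeObject(Array("status-0", "status-1")) // stand-in for Array[MapStatus]
    objOut.close()
    val bytes = out.toByteArray

    // Deserialize: check the tag, then decode from offset 1 onward
    assert(bytes(0) == DIRECT)
    val objIn = new ObjectInputStream(new GZIPInputStream(
      new ByteArrayInputStream(bytes, 1, bytes.length - 1)))
    val restored = objIn.readObject().asInstanceOf[Array[String]]
    objIn.close()
    println(restored.mkString(", ")) // status-0, status-1
  }
}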

trait MapStatus

This is a trait with two methods:

//The location where the task ran
def location: BlockManagerId
//Estimated size of the block destined for the given reducer
def getSizeForBlock(reduceId: Int): Long

Companion object MapStatus

It contains the following methods:

//Compare the number of uncompressed sizes against 2000 to decide which
//MapStatus implementation to return (HighlyCompressedMapStatus | CompressedMapStatus)
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
    if (uncompressedSizes.length > 2000) {
      HighlyCompressedMapStatus(loc, uncompressedSizes)
    } else {
      new CompressedMapStatus(loc, uncompressedSizes)
    }
  }

  private[this] val LOG_BASE = 1.1

  //Compute the compressed (one-byte) size
  def compressSize(size: Long): Byte = {
    if (size == 0) {
      0
    } else if (size <= 1L) {
      1
    } else { // math.log is the natural logarithm
      math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
    }
  }

  //Return the approximate decompressed size of a compressed value
  def decompressSize(compressedSize: Byte): Long = {
    if (compressedSize == 0) {
      0
    } else {
      math.pow(LOG_BASE, compressedSize & 0xFF).toLong
    }
  }
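
A worked example of the log-1.1 encoding (a standalone sketch; the CompressSizeDemo wrapper is hypothetical): a single byte covers sizes up to 1.1^255, roughly 35 GB, with at most about 10% error, because ceil() rounds the exponent up by less than one.

object CompressSizeDemo {
  private val LOG_BASE = 1.1

  def compressSize(size: Long): Byte =
    if (size == 0) 0
    else if (size <= 1L) 1
    else math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte

  def decompressSize(compressedSize: Byte): Long =
    if (compressedSize == 0) 0
    else math.pow(LOG_BASE, compressedSize & 0xFF).toLong

  def main(args: Array[String]): Unit = {
    val size = 1000000L
    val c = compressSize(size)   // exponent 145, stored in one byte
    val back = decompressSize(c) // about 1,004,500: within ~10% of the original
    println(s"$size -> ${c & 0xFF} -> $back")
  }
}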

CompressedMapStatus

In detail:

//The constructor takes two parameters: a BlockManagerId and compressedSizes (one compressed size byte per reduce partition)
private[spark] class CompressedMapStatus(
    private[this] var loc: BlockManagerId,
    private[this] var compressedSizes: Array[Byte])
  extends MapStatus with Externalizable {
  protected def this() = this(null, null.asInstanceOf[Array[Byte]])  // For deserialization only

  def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
    this(loc, uncompressedSizes.map(MapStatus.compressSize))
  }

  override def location: BlockManagerId = loc
//Get the approximate uncompressed size for a given reducer
  override def getSizeForBlock(reduceId: Int): Long = {
    MapStatus.decompressSize(compressedSizes(reduceId))
  }
//Serialization: write the location and the compressed sizes
  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
    loc.writeExternal(out)
    out.writeInt(compressedSizes.length)
    out.write(compressedSizes)
  }
//Deserialization: read the location and the compressed sizes back
  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
    loc = BlockManagerId(in)
    val len = in.readInt()
    compressedSizes = new Array[Byte](len)
    in.readFully(compressedSizes)
  }
}

HighlyCompressedMapStatus

In detail:

//The constructor takes five parameters:
//the BlockManagerId; numNonEmptyBlocks, the number of non-empty blocks;
//emptyBlocks, a bitmap recording the empty blocks; avgSize, the average size of the
//non-empty, non-huge blocks; and hugeBlockSizes, a map of each huge block's compressed size
private[spark] class HighlyCompressedMapStatus private (
    private[this] var loc: BlockManagerId,
    private[this] var numNonEmptyBlocks: Int,
    private[this] var emptyBlocks: RoaringBitmap,
    private[this] var avgSize: Long,
    private var hugeBlockSizes: Map[Int, Byte])
  extends MapStatus with Externalizable {

  // loc could be null when the default constructor is called during deserialization
  require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
    "Average size can only be zero for map stages that produced no output")
  protected def this() = this(null, -1, null, -1, null)  // For deserialization only

  override def location: BlockManagerId = loc
//Get the reducer's approximate size: 0 for an empty block, the recorded size for a huge block, avgSize otherwise
  override def getSizeForBlock(reduceId: Int): Long = {
    assert(hugeBlockSizes != null)
    if (emptyBlocks.contains(reduceId)) {
      0
    } else {
      hugeBlockSizes.get(reduceId) match {
        case Some(size) => MapStatus.decompressSize(size)
        case None => avgSize
      }
    }
  }
//Serialization: write the location, the empty-block bitmap, the average size, and the huge-block sizes
  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
    loc.writeExternal(out)
    emptyBlocks.writeExternal(out)
    out.writeLong(avgSize)
    out.writeInt(hugeBlockSizes.size)
    hugeBlockSizes.foreach { kv =>
      out.writeInt(kv._1)
      out.writeByte(kv._2)
    }
  }
//Deserialization: read the same fields back
  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
    loc = BlockManagerId(in)
    emptyBlocks = new RoaringBitmap()
    emptyBlocks.readExternal(in)
    avgSize = in.readLong()
    val count = in.readInt()
    val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]()
    (0 until count).foreach { _ =>
      val block = in.readInt()
      val size = in.readByte()
      hugeBlockSizesArray += Tuple2(block, size)
    }
    hugeBlockSizes = hugeBlockSizesArray.toMap
  }
}

object HighlyCompressedMapStatus

//Mainly computes the five constructor arguments of HighlyCompressedMapStatus
private[spark] object HighlyCompressedMapStatus {
  def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
    var i = 0
    var numNonEmptyBlocks: Int = 0
    var numSmallBlocks: Int = 0
    var totalSmallBlockSize: Long = 0
   
    val emptyBlocks = new RoaringBitmap()
    val totalNumBlocks = uncompressedSizes.length
    val threshold = Option(SparkEnv.get)
      .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD))
      .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get)
    val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]()
    //Classify each block into a bucket according to the configured threshold
    while (i < totalNumBlocks) {
      val size = uncompressedSizes(i)
      if (size > 0) {
        numNonEmptyBlocks += 1
        // Huge blocks are not included in the calculation for average size, thus size for smaller
        // blocks is more accurate.
        if (size < threshold) {
          totalSmallBlockSize += size
          numSmallBlocks += 1
        } else {
          hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i)))
        }
      } else {
        emptyBlocks.add(i)
      }
      i += 1
    }
    val avgSize = if (numSmallBlocks > 0) {
      totalSmallBlockSize / numSmallBlocks
    } else {
      0
    }
    emptyBlocks.trim()
    emptyBlocks.runOptimize()
    new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
      hugeBlockSizesArray.toMap)
  }
}
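
A small worked example of the bucketing above (a standalone sketch with the threshold hard-coded to the 100 MB default of spark.shuffle.accurateBlockThreshold): block sizes of 0, 1 MB, 2 MB and 500 MB produce one empty block, an average of 1.5 MB over the two small blocks, and one huge block whose size is recorded exactly.

object HighlyCompressedBucketsDemo {
  def main(args: Array[String]): Unit = {
    val threshold = 100L * 1024 * 1024                    // 100 MB default
    val sizes = Array(0L, 1L << 20, 2L << 20, 500L << 20) // 0, 1 MB, 2 MB, 500 MB
    val indexed = sizes.zipWithIndex
    val empty = indexed.collect { case (0, i) => i }
    val (small, huge) = indexed.filter(_._1 > 0).partition(_._1 < threshold)
    val avgSize = if (small.nonEmpty) small.map(_._1).sum / small.length else 0L
    // empty = [0], avgSize = 1572864 (1.5 MB), huge = [3]
    println(s"empty=${empty.toSeq}, avg=$avgSize, huge=${huge.map(_._2).toSeq}")
  }
}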

ShuffleStatus

This class uses the MapStatus trait and its subclasses.
Here is the full definition:

//The constructor takes the number of map partitions
private class ShuffleStatus(numPartitions: Int) {
//One MapStatus per map partition
  val mapStatuses = new Array[MapStatus](numPartitions)
//Caches the serialized MapStatus bytes
  private[this] var cachedSerializedMapStatus: Array[Byte] = _
//Caches the serialized broadcast
  private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
//Count of the available outputs
  private[this] var _numAvailableOutputs: Int = 0

//Register a map output; an existing entry is overwritten, and _numAvailableOutputs is kept up to date
  def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
    if (mapStatuses(mapId) == null) {
      _numAvailableOutputs += 1
      invalidateSerializedMapOutputStatusCache()
    }
    mapStatuses(mapId) = status
  }

  //Remove a map output identified by mapId (the partition id) and the location recorded in the map output
  def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
    if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
      _numAvailableOutputs -= 1
      mapStatuses(mapId) = null
      invalidateSerializedMapOutputStatusCache()
    }
  }

 //Remove all map outputs on a given host
  def removeOutputsOnHost(host: String): Unit = {
    removeOutputsByFilter(x => x.host == host)
  }

  //Remove all map outputs on a given executor
  def removeOutputsOnExecutor(execId: String): Unit = synchronized {
    removeOutputsByFilter(x => x.executorId == execId)
  }

  //Remove outputs selected by a filter function
  def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
    for (mapId <- 0 until mapStatuses.length) {
      if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) {
        _numAvailableOutputs -= 1
        mapStatuses(mapId) = null
        invalidateSerializedMapOutputStatusCache()
      }
    }
  }

 //Get the number of available map outputs
  def numAvailableOutputs: Int = synchronized {
    _numAvailableOutputs
  }

  //The ids of the entries in mapStatuses that are null
  def findMissingPartitions(): Seq[Int] = synchronized {
    val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
    assert(missing.size == numPartitions - _numAvailableOutputs,
      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
    missing
  }

  //Serialize all the MapStatuses
  def serializedMapStatus(
      broadcastManager: BroadcastManager,
      isLocal: Boolean,
      minBroadcastSize: Int): Array[Byte] = synchronized {
    if (cachedSerializedMapStatus eq null) {
      val serResult = MapOutputTracker.serializeMapStatuses(
          mapStatuses, broadcastManager, isLocal, minBroadcastSize)
      cachedSerializedMapStatus = serResult._1 // the serialized bytes (may embed a broadcast reference)
      cachedSerializedBroadcast = serResult._2 // null, or the broadcast that was created
    }
    cachedSerializedMapStatus
  }

  //Whether a serialized broadcast is cached
  def hasCachedSerializedBroadcast: Boolean = synchronized {
    cachedSerializedBroadcast != null
  }

  //Apply a function to the mapStatuses array while holding the lock
  def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized {
    f(mapStatuses)
  }

  /**
   * Clears the cached serialized map output statuses.
   */
  def invalidateSerializedMapOutputStatusCache(): Unit = synchronized {
    if (cachedSerializedBroadcast != null) {
      // Prevent errors during broadcast cleanup from crashing the DAGScheduler (see SPARK-21444)
      Utils.tryLogNonFatalError {
        // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup
        // RPCs to dead executors.
        cachedSerializedBroadcast.destroy(blocking = false)
      }
      cachedSerializedBroadcast = null
    }
    cachedSerializedMapStatus = null
  }
}