Spark-MapOutputTracker 源码解析
MapOutputTracker 一共有2种类型,一个是 MapOutputTrackerMaster,另一个是 MapOutputTrackerWorker。
MapOutputTrackerMaster
当new 这个MapOutputTrackerMaster对象的时候,会传递三个参数conf, broadcastManager, isLocal.
isLocal 在 yarn-cluster 模式下是false。
下面详细看看这个类,属性:
// Size threshold (in bytes) that decides whether the serialized map output statuses are
// shipped through a broadcast variable; it is handed to ShuffleStatus.serializedMapStatus.
private val minSizeForBroadcast =
  conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt
// Whether reducer data locality is enabled (true unless configured otherwise).
private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true)
// NOTE(review): the next three constants are not used in this excerpt; judging by their names
// they bound/tune the reducer-locality computation -- confirm against the full source.
private val SHUFFLE_PREF_MAP_THRESHOLD = 1000
private val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000
private val REDUCER_PREF_LOCS_FRACTION = 0.2
// State of every registered shuffle; the Int key is the shuffle ID.
val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala
// Maximum RPC message size in bytes, taken from the configuration
// (128 MB by default according to the original note; configurable).
private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
// Blocking queue holding the pending map-output requests.
private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
// Thread pool that drains mapOutputRequests. The pool size is read from
// spark.shuffle.mapOutput.dispatcher.numThreads (default 8) and one MessageLoop is
// submitted per thread.
private val threadpool: ThreadPoolExecutor = {
  val workerCount = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8)
  val executor = ThreadUtils.newDaemonFixedThreadPool(workerCount, "map-output-dispatcher")
  (0 until workerCount).foreach { _ =>
    executor.execute(new MessageLoop)
  }
  executor
}
// Configuration sanity check: the broadcast threshold must not exceed the maximum RPC
// message size. On violation the problem is logged and an IllegalArgumentException is thrown.
if (minSizeForBroadcast > maxRpcMessageSize) {
  val errorMessage =
    s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " +
      s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " +
      "message that is too large."
  logError(errorMessage)
  throw new IllegalArgumentException(errorMessage)
}
/** Enqueues a GetMapOutputMessage on mapOutputRequests for the dispatcher threads. */
def post(message: GetMapOutputMessage): Unit =
  mapOutputRequests.offer(message)
// Sentinel request (shuffle ID -99, null context). A MessageLoop that takes it from the queue
// puts it back for the other loops and then stops.
private val PoisonPill = new GetMapOutputMessage(-99, null)
/**
 * Counts the registered shuffles that currently hold a cached serialized broadcast.
 * shuffleStatuses contains every shuffle, and each ShuffleStatus may have serialized its
 * map statuses via a broadcast.
 */
private[spark] def getNumCachedSerializedBroadcast: Int = {
  shuffleStatuses.valuesIterator.count(status => status.hasCachedSerializedBroadcast)
}
/**
 * Registers a shuffle in shuffleStatuses.
 *
 * @param shuffleId ID of the shuffle
 * @param numMaps   number of map partitions; sizes the new ShuffleStatus
 * @throws IllegalArgumentException if the shuffle ID was already registered (the new
 *                                  ShuffleStatus has replaced the old entry by then)
 */
def registerShuffle(shuffleId: Int, numMaps: Int): Unit = {
  val previous = shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps))
  if (previous.nonEmpty) {
    throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
  }
}
/**
 * Records the status of one map output of an already registered shuffle.
 * Looking up an unregistered shuffle ID fails with the map's missing-key exception.
 */
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus): Unit = {
  val shuffleStatus = shuffleStatuses(shuffleId)
  shuffleStatus.addMapOutput(mapId, status)
}
/**
 * Unregisters one map output, identified by shuffle ID, map ID and block manager address,
 * and then increments the epoch.
 *
 * @throws SparkException if the shuffle ID is not registered
 */
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId): Unit = {
  val shuffleStatus = shuffleStatuses.getOrElse(
    shuffleId,
    throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID"))
  shuffleStatus.removeMapOutput(mapId, bmAddress)
  incrementEpoch()
}
/**
 * Removes a shuffle from shuffleStatuses and, if it was present, clears its cached
 * serialized map statuses. Unknown shuffle IDs are ignored.
 */
def unregisterShuffle(shuffleId: Int): Unit = {
  for (removed <- shuffleStatuses.remove(shuffleId)) {
    removed.invalidateSerializedMapOutputStatusCache()
  }
}
/** Asks every registered shuffle to drop its outputs on the given host, then bumps the epoch. */
def removeOutputsOnHost(host: String): Unit = {
  for (status <- shuffleStatuses.valuesIterator) {
    status.removeOutputsOnHost(host)
  }
  incrementEpoch()
}
/** Asks every registered shuffle to drop its outputs on the given executor, then bumps the epoch. */
def removeOutputsOnExecutor(execId: String): Unit = {
  for (status <- shuffleStatuses.valuesIterator) {
    status.removeOutputsOnExecutor(execId)
  }
  incrementEpoch()
}
/** Returns true when the given shuffle ID is registered in shuffleStatuses. */
def containsShuffle(shuffleId: Int): Boolean = {
  shuffleStatuses.isDefinedAt(shuffleId)
}
/** Number of available outputs of the given shuffle, or 0 when the shuffle is not registered. */
def getNumAvailableOutputs(shuffleId: Int): Int = {
  shuffleStatuses.get(shuffleId).fold(0)(_.numAvailableOutputs)
}
/**
 * Missing partitions of the given shuffle as reported by its ShuffleStatus,
 * or None when the shuffle is not registered.
 */
def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = {
  for (status <- shuffleStatuses.get(shuffleId)) yield status.findMissingPartitions()
}
/**
 * Splits `range` into consecutive sub-ranges of `size` elements each, keeping the step of
 * `range`. Every group spans `size * step`, so when the range length is not a multiple of
 * `size` the last group extends past `range.end`.
 */
def rangeGrouped(range: Range, size: Int): Seq[Range] = {
  val stride = size * range.step
  Range(range.start, range.end, stride).map { groupStart =>
    Range(groupStart, groupStart + stride, range.step)
  }
}
/**
 * Divides the indices 0 until numElements into buckets of near-equal size: the first
 * `numElements % numBuckets` buckets get one extra element. When there are fewer elements
 * than buckets, only `numElements` single-element buckets are returned.
 */
def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
  val baseSize = numElements / numBuckets
  val largerBuckets = numElements % numBuckets
  val boundary = (baseSize + 1) * largerBuckets
  val leading = rangeGrouped(0.until(boundary), baseSize + 1)
  if (baseSize == 0) {
    // A bucket size of 0 would be an illegal step, so the second group is skipped.
    leading
  } else {
    leading ++ rangeGrouped(boundary.until(numElements), baseSize)
  }
}
内部类:
/**
 * Loop executed by each dispatcher thread: takes requests from mapOutputRequests and replies
 * with the serialized map statuses of the requested shuffle until the PoisonPill is seen or
 * the thread is interrupted.
 */
private class MessageLoop extends Runnable {
  override def run(): Unit = {
    try {
      while (true) {
        try {
          // Blocks until a request (or the PoisonPill) is available.
          val data = mapOutputRequests.take()
          if (data == PoisonPill) {
            // The PoisonPill is the signal to stop all message-handling threads.
            // Put PoisonPill back so that other MessageLoops can see it.
            mapOutputRequests.offer(PoisonPill) // lets the other message-handling threads stop too
            return
          }
          val context = data.context
          val shuffleId = data.shuffleId
          val hostPort = context.senderAddress.hostPort
          logDebug("Handling request to send map output locations for shuffle " + shuffleId +
            " to " + hostPort)
          // NOTE(review): `.head` throws when the shuffle ID is not registered; the NonFatal
          // handler below then only logs, and no reply is sent on `context` on that path.
          val shuffleStatus = shuffleStatuses.get(shuffleId).head
          context.reply(
            shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast))
        } catch {
          // NonFatal does not match InterruptedException, so an interrupt reaches the outer
          // handler and ends the loop; other failures are logged and the loop continues.
          case NonFatal(e) => logError(e.getMessage, e)
        }
      }
    } catch {
      case ie: InterruptedException => // exit
    }
  }
}
MapOutputTrackerWorker
/**
 * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster.
 * Note that this is not used in local-mode; instead, local-mode Executors access the
 * MapOutputTrackerMaster directly (which is possible because the master and worker share a common
 * superclass).
 */
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
  // Local cache of the MapStatus array of each shuffle, keyed by shuffle ID.
  val mapStatuses: Map[Int, Array[MapStatus]] =
    new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
  /** Remembers which map output locations are currently being fetched on an executor. */
  private val fetching = new HashSet[Int]

  /**
   * Resolves the statuses of the shuffle and converts them for the given partition range.
   * If the conversion fails with MetadataFetchFailedException the whole local cache is
   * cleared before the exception is rethrown.
   */
  override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
  : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
    val statuses = getStatuses(shuffleId)
    try {
      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) // statuses for the partition range -> (location, [(shuffle block ID, size)])
    } catch {
      case e: MetadataFetchFailedException =>
        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
        mapStatuses.clear()
        throw e
    }
  }
  /**
   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
   * on this array when reading it, because on the driver, we may be changing it in place.
   *
   * (It would be nice to remove this restriction in the future.)
   */
  // Looks in the local mapStatuses cache first; on a miss the statuses are fetched from the driver.
  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
    val statuses = mapStatuses.get(shuffleId).orNull // cached statuses of this shuffle, or null
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      val startTime = System.currentTimeMillis
      var fetchedStatuses: Array[MapStatus] = null
      fetching.synchronized {
        // Someone else is fetching it; wait for them to be done
        while (fetching.contains(shuffleId)) {
          try {
            fetching.wait()
          } catch {
            // An interrupt is ignored here; the loop condition is simply re-checked.
            case e: InterruptedException =>
          }
        }
        // Either while we waited the fetch happened successfully, or
        // someone fetched it in between the get and the fetching.synchronized.
        fetchedStatuses = mapStatuses.get(shuffleId).orNull
        if (fetchedStatuses == null) {
          // We have to do the fetch, get others to wait for us.
          fetching += shuffleId
        }
      }
      if (fetchedStatuses == null) {
        // We won the race to fetch the statuses; do so
        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
        // This try-finally prevents hangs due to timeouts:
        try {
          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) // ask the tracker endpoint for the serialized statuses of this shuffle
          // Per the original note, the request is answered in the receiveAndReply method of
          // MapOutputTrackerMasterEndpoint (not part of this excerpt).
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) // deserialize into Array[MapStatus]
          logInfo("Got the output locations")
          mapStatuses.put(shuffleId, fetchedStatuses) // store the result in the local cache
        } finally {
          // Always release the marker and wake the waiters, even when the fetch failed.
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      }
      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
        s"${System.currentTimeMillis - startTime} ms")
      if (fetchedStatuses != null) {
        fetchedStatuses
      } else {
        logError("Missing all output locations for shuffle " + shuffleId)
        throw new MetadataFetchFailedException(
          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
      }
    } else {
      statuses
    }
  }
  /** Unregister shuffle data. */
  // Removes the cached statuses of this shuffle ID from mapStatuses.
  def unregisterShuffle(shuffleId: Int): Unit = {
    mapStatuses.remove(shuffleId)
  }
  /**
   * Called from executors to update the epoch number, potentially clearing old outputs
   * because of a fetch failure. Each executor task calls this with the latest epoch
   * number on the driver at the time it was created.
   */
  // Per the original note, the run method in org.apache.spark.executor.Executor calls this to
  // update the executor's epoch. The cache is cleared only when the new epoch is larger.
  def updateEpoch(newEpoch: Long): Unit = {
    epochLock.synchronized {
      if (newEpoch > epoch) {
        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
        epoch = newEpoch
        mapStatuses.clear()
      }
    }
  }
}
abstract MapOutputTracker
这个类是公共的 抽象类:
// Abstract base class of the tracker; it is constructed from a SparkConf.
abstract class MapOutputTracker(conf: SparkConf) extends Logging
属性:
// Reference to the tracker's RPC endpoint (per the original note, the endpoint lives in the driver).
var trackerEndpoint: RpcEndpointRef = _ // tracker endpoint reference
protected var epoch: Long = 0 // counter, starting at 0 (a driver-side count per the original note)
// Lock object for synchronizing on epoch.
protected val epochLock = new AnyRef
方法:
/**
 * Sends a message to trackerEndpoint and synchronously returns its reply as `T`
 * (per the original note, within the default timeout).
 *
 * @throws SparkException wrapping any Exception raised while asking; the failure is logged first
 */
protected def askTracker[T: ClassTag](message: Any): T =
  try trackerEndpoint.askSync[T](message)
  catch {
    case cause: Exception =>
      logError("Error communicating with MapOutputTracker", cause)
      throw new SparkException("Error communicating with MapOutputTracker", cause)
  }
/**
 * Sends a message to trackerEndpoint through askTracker and verifies the Boolean
 * acknowledgement. Although it returns nothing to the caller, it does wait for the reply.
 *
 * @throws SparkException if the reply is not `true`
 */
protected def sendTracker(message: Any): Unit = {
  val acknowledged = askTracker[Boolean](message)
  if (!acknowledged) {
    throw new SparkException(
      "Error reply received from MapOutputTracker. Expecting true, got " + acknowledged.toString)
  }
}
/**
 * Single-partition convenience overload: delegates to the range variant with the
 * partition range [reduceId, reduceId + 1).
 */
def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
    : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  val endPartition = reduceId + 1
  getMapSizesByExecutorId(shuffleId, reduceId, endPartition)
}
// Abstract: for the partitions from startPartition to endPartition of a shuffle, returns
// pairs of a block manager and its (block ID, size) entries. Implemented by subclasses.
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Seq[(BlockManagerId, Seq[(BlockId, Long)])]
// Abstract: unregisters the shuffle with the given ID. Implemented by subclasses.
def unregisterShuffle(shuffleId: Int): Unit
/** Stops the tracker; the base implementation does nothing. */
def stop(): Unit = {}
class ShuffleStatus
这个类用来保存一个shuffle的状态信息。
属性:
// Array of MapStatus entries, one slot per partition.
// Its length, numPartitions, is the total number of map outputs of the shuffle.
val mapStatuses = new Array[MapStatus](numPartitions)
// Cached serialized form of the map statuses.
private[this] var cachedSerializedMapStatus: Array[Byte] = _
// Cached broadcast holding the serialized statuses.
private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
// Count of available outputs.
private[this] var _numAvailableOutputs: Int = 0
方法:
// Stores a status in the internal mapStatuses array, overwriting any existing entry.
// The available-output count and the serialized cache change only when the slot was empty.
def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
  if (mapStatuses(mapId) == null) {
    _numAvailableOutputs += 1
    invalidateSerializedMapOutputStatusCache()
  }
  mapStatuses(mapId) = status
}
// Clears the cached serialized map output statuses and destroys the cached broadcast, if any.
def invalidateSerializedMapOutputStatusCache(): Unit = synchronized {
  if (cachedSerializedBroadcast != null) {
    // Prevent errors during broadcast cleanup from crashing the DAGScheduler (see SPARK-21444)
    Utils.tryLogNonFatalError {
      // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup
      // RPCs to dead executors.
      cachedSerializedBroadcast.destroy(blocking = false)
    }
    cachedSerializedBroadcast = null
  }
  cachedSerializedMapStatus = null
}
// Removes the status stored for mapId, but only if its location equals bmAddress.
def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
  if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
    _numAvailableOutputs -= 1 // one fewer available output
    mapStatuses(mapId) = null
    invalidateSerializedMapOutputStatusCache() // Clears the cached serialized map output statuses.
  }
}
// Removes every output whose location satisfies the filter function f.
def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
  for (mapId <- 0 until mapStatuses.length) {
    if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) {
      _numAvailableOutputs -= 1
      mapStatuses(mapId) = null
      invalidateSerializedMapOutputStatusCache()
    }
  }
}
// Removes the outputs whose location's host equals the given host.
def removeOutputsOnHost(host: String): Unit = {
  removeOutputsByFilter(x => x.host == host)
}
// Removes the outputs whose location's executorId equals the given executor ID.
def removeOutputsOnExecutor(execId: String): Unit = synchronized {
  removeOutputsByFilter(x => x.executorId == execId)
}
// Returns the current count of available outputs.
def numAvailableOutputs: Int = synchronized {
  _numAvailableOutputs
}
// Returns the partition IDs whose entry in mapStatuses is null. The assertion cross-checks
// the result against the available-output counter.
def findMissingPartitions(): Seq[Int] = synchronized {
  val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
  assert(missing.size == numPartitions - _numAvailableOutputs,
    s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
  missing
}
// Returns the serialized map statuses, serializing them via
// MapOutputTracker.serializeMapStatuses only when no cached value exists, and caching both
// the bytes and the broadcast from that call.
def serializedMapStatus(
    broadcastManager: BroadcastManager,
    isLocal: Boolean,
    minBroadcastSize: Int): Array[Byte] = synchronized {
  if (cachedSerializedMapStatus eq null) {
    val serResult = MapOutputTracker.serializeMapStatuses(
      mapStatuses, broadcastManager, isLocal, minBroadcastSize)
    cachedSerializedMapStatus = serResult._1
    cachedSerializedBroadcast = serResult._2
  }
  cachedSerializedMapStatus
}
object MapOutputTracker
这个是 MapOutputTracker的伴生对象。里面有一些公共的方法。
// Name constant for the tracker endpoint.
val ENDPOINT_NAME = "MapOutputTracker"
private val DIRECT = 0 // tag byte: payload is the directly serialized statuses
private val BROADCAST = 1 // tag byte: payload is a serialized Broadcast wrapping the statuses
/**
 * Serializes the map statuses.
 *
 * The array is first written as a GZIP-compressed Java-serialized object prefixed with the
 * DIRECT tag byte. If that result is at least `minBroadcastSize` bytes long, it is wrapped in
 * a broadcast variable and the returned bytes instead hold the BROADCAST tag followed by the
 * GZIP-compressed serialized Broadcast.
 *
 * @param statuses         statuses to serialize; the array is locked while it is written
 * @param broadcastManager used to create the broadcast when the threshold is reached
 * @param isLocal          passed through to `broadcastManager.newBroadcast`
 * @param minBroadcastSize size threshold in bytes for switching to a broadcast
 * @return the serialized bytes and the broadcast that was created, or null if none was
 */
def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager,
    isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = {
  val out = new ByteArrayOutputStream
  out.write(DIRECT)
  val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
  Utils.tryWithSafeFinally {
    // Since statuses can be modified in parallel, sync on it
    statuses.synchronized {
      objOut.writeObject(statuses)
    }
  } {
    objOut.close()
  }
  val arr = out.toByteArray
  if (arr.length >= minBroadcastSize) {
    // Use broadcast instead.
    // Important arr(0) is the tag == DIRECT, ignore that while deserializing !
    val bcast = broadcastManager.newBroadcast(arr, isLocal)
    // toByteArray creates copy, so we can reuse out
    out.reset()
    out.write(BROADCAST)
    val oos = new ObjectOutputStream(new GZIPOutputStream(out))
    // Close the stream even when writeObject fails, so the GZIP deflater is always released
    // (same handling as the first stream above).
    Utils.tryWithSafeFinally {
      oos.writeObject(bcast)
    } {
      oos.close()
    }
    val outArr = out.toByteArray
    logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length)
    (outArr, bcast)
  } else {
    (arr, null)
  }
}
/**
 * Deserializes bytes produced by serializeMapStatuses back into an Array[MapStatus].
 * The first byte is a tag: DIRECT means the rest is the compressed array itself, BROADCAST
 * means the rest is a compressed Broadcast whose value holds a DIRECT-tagged payload.
 *
 * @throws IllegalArgumentException if the tag byte is neither DIRECT nor BROADCAST
 */
def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
  assert (bytes.length > 0)

  // Reads one GZIP-compressed, Java-serialized object from payload[offset, offset + length).
  def readCompressedObject(payload: Array[Byte], offset: Int, length: Int): AnyRef = {
    val source = new ByteArrayInputStream(payload, offset, length)
    val objIn = new ObjectInputStream(new GZIPInputStream(source))
    Utils.tryWithSafeFinally {
      objIn.readObject()
    } {
      objIn.close()
    }
  }

  val tag = bytes(0)
  tag match {
    case DIRECT =>
      readCompressedObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]]
    case BROADCAST =>
      // deserialize the Broadcast, pull .value array out of it, and then deserialize that
      val bcast = readCompressedObject(bytes, 1, bytes.length - 1)
        .asInstanceOf[Broadcast[Array[Byte]]]
      logInfo("Broadcast mapstatuses size = " + bytes.length +
        ", actual size = " + bcast.value.length)
      // Important - ignore the DIRECT tag ! Start from offset 1
      val inner = bcast.value
      readCompressedObject(inner, 1, inner.length - 1).asInstanceOf[Array[MapStatus]]
    case _ => throw new IllegalArgumentException("Unexpected byte tag = " + tag)
  }
}
trait MapStatus
这个是一个接口,里面有2个方法:
// Block manager location; per the original note, this identifies where the task ran.
def location: BlockManagerId
// Estimated size of the block for the given reducer ID.
def getSizeForBlock(reduceId: Int): Long
伴生对象 object MapStatus
里面有3个方法和1个常量 LOG_BASE:
/**
 * Chooses the MapStatus implementation from the number of blocks: more than 2000 sizes
 * yield a HighlyCompressedMapStatus, otherwise a CompressedMapStatus.
 */
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
  val exceedsLimit = uncompressedSizes.length > 2000
  if (exceedsLimit) HighlyCompressedMapStatus(loc, uncompressedSizes)
  else new CompressedMapStatus(loc, uncompressedSizes)
}
// Base of the logarithm used to squeeze a size into a single byte.
private[this] val LOG_BASE = 1.1

/**
 * Compresses a size into one byte: 0 stays 0, any other size up to 1 becomes 1, and larger
 * sizes become ceil(log base LOG_BASE of size), capped at 255 (stored as an unsigned byte).
 */
def compressSize(size: Long): Byte = size match {
  case 0L => 0
  case tiny if tiny <= 1L => 1
  case large =>
    // math.log is the natural logarithm; dividing by log(LOG_BASE) changes the base.
    val exponent = math.ceil(math.log(large) / math.log(LOG_BASE)).toInt
    math.min(255, exponent).toByte
}

/**
 * Expands a compressed size: 0 stays 0, otherwise LOG_BASE raised to the unsigned byte
 * value, truncated to a Long.
 */
def decompressSize(compressedSize: Byte): Long = {
  val unsigned = compressedSize & 0xFF
  if (unsigned == 0) 0L else math.pow(LOG_BASE, unsigned).toLong
}
CompressedMapStatus
详细看看:
/**
 * MapStatus that keeps one compressed size byte per reduce block.
 * The primary constructor takes the BlockManagerId and the already compressed sizes.
 */
private[spark] class CompressedMapStatus(
    private[this] var loc: BlockManagerId,
    private[this] var compressedSizes: Array[Byte])
  extends MapStatus with Externalizable {
  // Used for deserialization only.
  protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
  // Builds the status from uncompressed sizes, compressing each with MapStatus.compressSize.
  def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
    this(loc, uncompressedSizes.map(MapStatus.compressSize))
  }
  override def location: BlockManagerId = loc
  // Returns the decompressed size stored for the given reducer.
  override def getSizeForBlock(reduceId: Int): Long = {
    MapStatus.decompressSize(compressedSizes(reduceId))
  }
  // Writes the location, then the number of size bytes, then the bytes themselves.
  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
    loc.writeExternal(out)
    out.writeInt(compressedSizes.length)
    out.write(compressedSizes)
  }
  // Reads the fields back in the order used by writeExternal.
  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
    loc = BlockManagerId(in)
    val len = in.readInt()
    compressedSizes = new Array[Byte](len)
    in.readFully(compressedSizes)
  }
}
HighlyCompressedMapStatus
详细看看:
/**
 * MapStatus that stores a bitmap of empty blocks, an average size for the ordinary blocks and
 * individual compressed sizes for the huge blocks.
 *
 * Constructor parameters: the BlockManagerId, the number of non-empty blocks, a bitmap of the
 * empty blocks, the average size of the non-empty, non-huge blocks, and a map from block index
 * to compressed size for the huge blocks.
 */
private[spark] class HighlyCompressedMapStatus private (
    private[this] var loc: BlockManagerId,
    private[this] var numNonEmptyBlocks: Int,
    private[this] var emptyBlocks: RoaringBitmap,
    private[this] var avgSize: Long,
    private var hugeBlockSizes: Map[Int, Byte])
  extends MapStatus with Externalizable {
  // loc could be null when the default constructor is called during deserialization
  require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
    "Average size can only be zero for map stages that produced no output")
  // Used for deserialization only.
  protected def this() = this(null, -1, null, -1, null) // For deserialization only
  override def location: BlockManagerId = loc
  // Size for a reducer: 0 for an empty block, the decompressed stored size for a huge block,
  // otherwise avgSize.
  override def getSizeForBlock(reduceId: Int): Long = {
    assert(hugeBlockSizes != null)
    if (emptyBlocks.contains(reduceId)) {
      0
    } else {
      hugeBlockSizes.get(reduceId) match {
        case Some(size) => MapStatus.decompressSize(size)
        case None => avgSize
      }
    }
  }
  // Writes the location, the empty-block bitmap, avgSize, then the huge-block count followed
  // by (block index, size byte) pairs. numNonEmptyBlocks is not written.
  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
    loc.writeExternal(out)
    emptyBlocks.writeExternal(out)
    out.writeLong(avgSize)
    out.writeInt(hugeBlockSizes.size)
    hugeBlockSizes.foreach { kv =>
      out.writeInt(kv._1)
      out.writeByte(kv._2)
    }
  }
  // Reads the fields back in the order used by writeExternal. numNonEmptyBlocks is not
  // assigned here.
  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
    loc = BlockManagerId(in)
    emptyBlocks = new RoaringBitmap()
    emptyBlocks.readExternal(in)
    avgSize = in.readLong()
    val count = in.readInt()
    val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]()
    (0 until count).foreach { _ =>
      val block = in.readInt()
      val size = in.readByte()
      hugeBlockSizesArray += Tuple2(block, size)
    }
    hugeBlockSizes = hugeBlockSizesArray.toMap
  }
}
object HighlyCompressedMapStatus
/** Factory that computes the five constructor arguments of HighlyCompressedMapStatus. */
private[spark] object HighlyCompressedMapStatus {
  /**
   * Classifies every block size as empty (not positive), small (below the accurate-block
   * threshold) or huge, and builds the status from the resulting bitmap, the average small
   * size and the compressed huge sizes.
   */
  def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
    // Threshold from the active SparkEnv's conf, or the config default when there is no SparkEnv.
    val threshold = Option(SparkEnv.get)
      .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD))
      .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get)
    val emptyBlocks = new RoaringBitmap()
    val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]()
    var numNonEmptyBlocks: Int = 0
    var smallBlockCount: Int = 0
    var smallBlockBytes: Long = 0
    for (blockIndex <- 0 until uncompressedSizes.length) {
      val size = uncompressedSizes(blockIndex)
      if (size > 0) {
        numNonEmptyBlocks += 1
        // Huge blocks are not included in the calculation for average size, thus size for smaller
        // blocks is more accurate.
        if (size >= threshold) {
          hugeBlockSizesArray += Tuple2(blockIndex, MapStatus.compressSize(size))
        } else {
          smallBlockBytes += size
          smallBlockCount += 1
        }
      } else {
        emptyBlocks.add(blockIndex)
      }
    }
    val avgSize = if (smallBlockCount > 0) {
      smallBlockBytes / smallBlockCount
    } else {
      0
    }
    emptyBlocks.trim()
    emptyBlocks.runOptimize()
    new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
      hugeBlockSizesArray.toMap)
  }
}
ShuffleStatus
这个类里面会使用MapStatus的接口和子类。
下面来看看:
/**
 * Tracks the map output statuses of one shuffle together with a cache of their serialized form.
 * The constructor takes the number of partitions.
 */
private class ShuffleStatus(numPartitions: Int) {
  // MapStatus of every partition; a null slot is a missing output.
  val mapStatuses = new Array[MapStatus](numPartitions)
  // Cached bytes of the already serialized map statuses.
  private[this] var cachedSerializedMapStatus: Array[Byte] = _
  // Cached broadcast of the already serialized data.
  private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
  // Count of available outputs.
  private[this] var _numAvailableOutputs: Int = 0
  // Registers a map output, overwriting any existing entry. The counter and the serialized
  // cache are updated only when the slot was empty.
  def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
    if (mapStatuses(mapId) == null) {
      _numAvailableOutputs += 1
      invalidateSerializedMapOutputStatusCache()
    }
    mapStatuses(mapId) = status
  }
  // Removes one map output, identified by map ID (the partition ID) and the location stored
  // in the status; nothing happens when the stored location differs.
  def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
    if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
      _numAvailableOutputs -= 1
      mapStatuses(mapId) = null
      invalidateSerializedMapOutputStatusCache()
    }
  }
  // Removes all map outputs located on the given host.
  def removeOutputsOnHost(host: String): Unit = {
    removeOutputsByFilter(x => x.host == host)
  }
  // Removes all map outputs located on the given executor ID.
  def removeOutputsOnExecutor(execId: String): Unit = synchronized {
    removeOutputsByFilter(x => x.executorId == execId)
  }
  // Removes every map output whose location satisfies the predicate f.
  def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
    for (mapId <- 0 until mapStatuses.length) {
      if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) {
        _numAvailableOutputs -= 1
        mapStatuses(mapId) = null
        invalidateSerializedMapOutputStatusCache()
      }
    }
  }
  // Returns the count of available map outputs.
  def numAvailableOutputs: Int = synchronized {
    _numAvailableOutputs
  }
  // Returns the IDs whose entry in mapStatuses is null; the assertion cross-checks the result
  // against the counter.
  def findMissingPartitions(): Seq[Int] = synchronized {
    val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
    assert(missing.size == numPartitions - _numAvailableOutputs,
      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
    missing
  }
  // Serializes all map statuses, reusing the cached bytes when they exist.
  def serializedMapStatus(
      broadcastManager: BroadcastManager,
      isLocal: Boolean,
      minBroadcastSize: Int): Array[Byte] = synchronized {
    if (cachedSerializedMapStatus eq null) {
      val serResult = MapOutputTracker.serializeMapStatuses(
        mapStatuses, broadcastManager, isLocal, minBroadcastSize)
      cachedSerializedMapStatus = serResult._1 // serialized result; may describe a broadcast
      cachedSerializedBroadcast = serResult._2 // the broadcast, or null when none was created
    }
    cachedSerializedMapStatus
  }
  // Whether a broadcast is currently cached.
  def hasCachedSerializedBroadcast: Boolean = synchronized {
    cachedSerializedBroadcast != null
  }
  // Applies f to the mapStatuses array while holding this object's lock.
  def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized {
    f(mapStatuses)
  }
  /**
   * Clears the cached serialized map output statuses.
   */
  def invalidateSerializedMapOutputStatusCache(): Unit = synchronized {
    if (cachedSerializedBroadcast != null) {
      // Prevent errors during broadcast cleanup from crashing the DAGScheduler (see SPARK-21444)
      Utils.tryLogNonFatalError {
        // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup
        // RPCs to dead executors.
        cachedSerializedBroadcast.destroy(blocking = false)
      }
      cachedSerializedBroadcast = null
    }
    cachedSerializedMapStatus = null
  }
}