Shuffle Write 请看 Shuffle Write解析。
本文将讲解Shuffle的Reduce端(Shuffle Read)部分。Shuffle下游Stage的第一个RDD是ShuffledRDD,通过其compute方法可以得到一个迭代器,用来读取上游Stage在Shuffle Write阶段溢写到磁盘文件中的数据:
/**
 * Produces the records of one partition by asking the shuffle manager of the current
 * SparkEnv for a reader and returning the iterator that the reader yields.
 */
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  // The first dependency is cast to the shuffle dependency that feeds this RDD.
  val shuffleDep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  val manager = SparkEnv.get.shuffleManager
  // Request the range split.index .. split.index + 1
  // (presumably end-exclusive, i.e. just this partition - confirm against getReader).
  val reader = manager.getReader(shuffleDep.shuffleHandle, split.index, split.index + 1, context)
  val records = reader.read()
  records.asInstanceOf[Iterator[(K, C)]]
}
从SparkEnv中获取shuffleManager(这里是SortShuffleManager),通过manager获取Reader并调用其read方法来得到一个迭代器。
/**
 * Creates a BlockStoreShuffleReader for the partition range given by
 * startPartition and endPartition.
 */
override def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext): ShuffleReader[K, C] = {
  // The reader is constructed from a BaseShuffleHandle, so the generic handle is cast first.
  val baseHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]]
  new BlockStoreShuffleReader(baseHandle, startPartition, endPartition, context)
}
getReader方法实例化了一个BlockStoreShuffleReader,参数中包含需要读取的分区范围(startPartition和endPartition),看看其read方法:
/**
 * Builds the iterator of key/value records for this reader: fetches the shuffle blocks,
 * deserializes them into records, aggregates by key when the dependency defines an
 * aggregator, and sorts by key when the dependency defines a key ordering.
 */
override def read(): Iterator[Product2[K, C]] = {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
// Metadata describing where the data for the requested partitions is stored
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
// Maximum size transferred per remote request; read in MB (default "48m") and converted to bytes
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
// Value of spark.reducer.maxReqsInFlight, defaulting to Int.MaxValue
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
// Wrap each fetched stream with compression and encryption handling
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
serializerManager.wrapStream(blockId, inputStream)
}
val serializerInstance = dep.serializer.newInstance()
// Turn every wrapped stream into a key/value iterator and concatenate them
val recordIter = wrappedStreams.flatMap { wrappedStream =>
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// Temporary metrics object; its record counter is incremented for every record read
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
// Build the full record iterator; the second argument merges the temporary read metrics
// (presumably run once the iterator is exhausted - confirm against CompletionIterator)
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
},
context.taskMetrics().mergeShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// Values were already combined once on the map side, so merge the combiners
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// Aggregation happens only on the reduce side, so combine the raw values
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
// Without an aggregator, map-side combine must not have been requested
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// Sort the output when the dependency defines a key ordering
dep.keyOrdering match {
// NOTE(review): the type argument K in this pattern is erased at runtime, so only the
// Some(...) shape is actually checked here.
case Some(keyOrd: Ordering[K]) =>
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
// Record the sorter's spill and peak-memory figures in the task metrics
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
// sorter.stop() is passed as the completion action of the returned iterator
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
// No ordering requested: return the (possibly aggregated) records as they are
aggregatedIter
}
}
首先实例化了ShuffleBlockFetcherIterator对象,其中一个参数:
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
该方法获取reduce端数据的来源的元数据,返回的是 Seq[(BlockManagerId, Seq[(BlockId, Long)])],即数据是来自于哪个节点的哪些block的,并且block的数据大小是多少,看看getMapSizesByExecutorId是怎么实现的:
/**
 * Looks up the map output statuses of the given shuffle and converts them into, per
 * BlockManagerId, the list of (BlockId, size) pairs for the requested partition range.
 */
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
    : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
  // Fetch the status metadata for the shuffle.
  val outputStatuses = getStatuses(shuffleId)
  // Convert while holding the monitor of the status array; the block's value is the result.
  outputStatuses.synchronized {
    MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, outputStatuses)
  }
}
- 传入shuffleId获取对应shuffle的所有元数据信息
- 转换格式并获取指定partition的元数据
跟进getStatuses:
private def getStatuses(shuffleId: Int): Array[MapStatus] = {
// 直接从mapStatuses中获取
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
val startTime = System.currentTimeMillis
var fetchedStatuses: Array[MapStatus] = null
......
if (fetchedStatuses == null) {
// We won the race to fetch the statuses; do so
logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
// This try-finally prevents hangs due to timeouts:
try {
// 从远程获取元数据
val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
// 反序列化
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")