Let's look at ShuffleMapTask#runTask:
override def runTask(context: TaskContext): MapStatus = {
  // ...
  val manager = SparkEnv.get.shuffleManager
  writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
  writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
  writer.stop(success = true).get
First it grabs the shuffleManager from SparkEnv, then asks it for a writer via getWriter. Before going further, a few words about ShuffleManager.
ShuffleManager is an interface; its main responsibilities are:
(1) registerShuffle: when an RDD builds its parent dependency (specifically a ShuffleDependency), it first registers with the ShuffleManager and obtains a ShuffleHandle, used later for reading and writing the shuffle's data blocks.
(2) getWriter: obtains a block writer through the ShuffleHandle. When writing, the shuffle block resolver shuffleBlockResolver provides the write location (usually abstracted as a bucket; which bucket is chosen by the shuffle rules, i.e. the shuffle's partitioner), and the data is written there. In theory the location could be anywhere that can store data, including disk, memory, or other storage frameworks; in the current pluggable implementations Spark, like Hadoop, stores shuffle output on disk, mainly to save memory and to improve fault tolerance.
(3) getReader: obtains a block reader through the ShuffleHandle, then fetches the requested blocks via the shuffle block resolver shuffleBlockResolver.
(4) unregisterShuffle: the counterpart of registration; removes metadata and performs other cleanup.
(5) shuffleBlockResolver: the shuffle block resolver itself, which provides the supporting layer for block reads and writes and abstracts away the concrete storage details.
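Putting those together, the trait looks roughly like this (signatures simplified from the Spark 2.x source):

trait ShuffleManager {
  // Register a shuffle when a ShuffleDependency is created; returns a handle for later reads/writes
  def registerShuffle[K, V, C](
      shuffleId: Int,
      numMaps: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle

  // Get a writer for a given map output partition; called on executors by map tasks
  def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]

  // Get a reader for a range of reduce partitions; called on executors by reduce tasks
  def getReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): ShuffleReader[K, C]

  // Remove a shuffle's metadata
  def unregisterShuffle(shuffleId: Int): Boolean

  // Maps logical shuffle blocks to physical file locations
  def shuffleBlockResolver: ShuffleBlockResolver

  def stop(): Unit
}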
Back in the performance-tuning chapter we mentioned the parameter spark.shuffle.manager, which selects the ShuffleManager implementation. HashShuffleManager was removed in Spark 2.0, so we will focus on SortShuffleManager.
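As a side note, this is roughly how that config value is resolved (simplified from SparkEnv#create in 2.x); note that both "sort" and "tungsten-sort" map to SortShuffleManager:

// Let the user specify short names for shuffle managers
val shortShuffleMgrNames = Map(
  "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName,
  "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName)
val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)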
Its getWriter method:
override def getWriter[K, V](
    handle: ShuffleHandle,
    mapId: Int,
    context: TaskContext): ShuffleWriter[K, V] = {
  numMapsForShuffle.putIfAbsent(
    handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
  val env = SparkEnv.get
  handle match {
    case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
      new UnsafeShuffleWriter(
        env.blockManager,
        shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
        context.taskMemoryManager(),
        unsafeShuffleHandle,
        mapId,
        context,
        env.conf)
    case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
      new BypassMergeSortShuffleWriter(
        env.blockManager,
        shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
        bypassMergeSortHandle,
        mapId,
        context,
        env.conf)
    case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
      new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
  }
}
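Which writer you get is decided earlier, when registerShuffle picks the handle. For reference, SortShuffleManager#registerShuffle chooses roughly like this (simplified from the 2.x source):

override def registerShuffle[K, V, C](
    shuffleId: Int,
    numMaps: Int,
    dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
  if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
    // No map-side combine and few enough reduce partitions:
    // write per-partition files and concatenate them (BypassMergeSortShuffleWriter)
    new BypassMergeSortShuffleHandle[K, V](
      shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
    // Serializer supports relocation, no aggregator, partition count fits the encoding:
    // sort serialized records directly (UnsafeShuffleWriter)
    new SerializedShuffleHandle[K, V](
      shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else {
    // Otherwise buffer records in deserialized form (SortShuffleWriter)
    new BaseShuffleHandle(shuffleId, numMaps, dependency)
  }
}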
Here we meet two writers we already know: BypassMergeSortShuffleWriter and SortShuffleWriter, both discussed among the tuning parameters, so we won't repeat that here. We'll follow SortShuffleWriter. But first, let's look at the rdd.iterator method:
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
  } else {
    computeOrReadCheckpoint(split, context)
  }
}
The if branch reads from the cache; we'll cover caching later. Let's go into computeOrReadCheckpoint:
if (isCheckpointedAndMaterialized) {
  firstParent[T].iterator(split, context)
} else {
  compute(split, context)
}
Here the if branch reads from a checkpoint (also covered later).
Now look at compute. It is an abstract method, so pick any concrete RDD, say MapPartitionsRDD#compute:
override def compute(split: Partition, context: TaskContext): Iterator[U] =
  f(context, split.index, firstParent[T].iterator(split, context))
The f here is what invokes the function we passed to the operator.
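As a reference point for where that f comes from, RDD#map wraps the user function into exactly this shape (simplified from the Spark source):

def map[U: ClassTag](f: T => U): RDD[U] = withScope {
  val cleanF = sc.clean(f) // clean the closure and check serializability
  // (context, partitionIndex, iterator) => new iterator: the f seen in compute above
  new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
}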
Now into SortShuffleWriter's write method:
override def write(records: Iterator[Product2[K, V]]): Unit = {
  sorter = if (dep.mapSideCombine) {
    require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
    new ExternalSorter[K, V, C](
      context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
  } else {
    // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
    // care whether the keys get sorted in each partition; that will be done on the reduce side
    // if the operation being run is sortByKey.
    new ExternalSorter[K, V, V](
      context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
  }
  sorter.insertAll(records)
  // Don't bother including the time to open the merged output file in the shuffle write time,
  // because it just opens a single file, so is typically too fast to measure accurately
  // (see SPARK-3570).
  val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
  val tmp = Utils.tempFileWith(output)
  val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
  val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
  shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
  mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
First it creates an ExternalSorter. We won't dissect it in depth here; just know that it is responsible for buffering, aggregating and sorting the map output.
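For orientation, these are the two ExternalSorter entry points used below (signatures simplified from the 2.x source):

// Feed all map output records into the in-memory collection, spilling to disk as needed
def insertAll(records: Iterator[Product2[K, V]]): Unit

// Merge the in-memory data and all spill files into a single sorted, partitioned data file;
// returns the byte length of each partition (later used to build the index file)
def writePartitionedFile(blockId: BlockId, outputFile: File): Array[Long]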
Into sorter.insertAll:
while (records.hasNext) {
  addElementsRead()
  val kv = records.next()
  buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
  maybeSpillCollection(usingMap = false)
}
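The snippet above is the branch without map-side combine, which appends into a buffer. For comparison, the combine branch of the same method merges each value into a map as it arrives (simplified from the source):

val mergeValue = aggregator.get.mergeValue
val createCombiner = aggregator.get.createCombiner
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
  // Merge into the existing combiner, or create a new one for a first-seen key
  if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
while (records.hasNext) {
  addElementsRead()
  kv = records.next()
  map.changeValue((getPartition(kv._1), kv._1), update)
  maybeSpillCollection(usingMap = true)
}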
Either way, records are first written into an in-memory collection (the map when combining, the buffer otherwise); each insertion then calls maybeSpillCollection, which spills to disk once the collection grows too large:
if (maybeSpill(map, estimatedSize)) {
  map = new PartitionedAppendOnlyMap[K, C]
}
maybeSpill decides whether to spill to disk:
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
  var shouldSpill = false
  if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
    // Claim up to double our current memory from the shuffle memory pool
    val amountToRequest = 2 * currentMemory - myMemoryThreshold
    val granted =
      taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP, null)
    myMemoryThreshold += granted
    // If we were granted too little memory to grow further (either tryToAcquire returned 0,
    // or we already had more memory than myMemoryThreshold), spill the current collection
    shouldSpill = currentMemory >= myMemoryThreshold
  }
  shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
  // Actually spill
  if (shouldSpill) {
    _spillCount += 1
    logSpillage(currentMemory)
    spill(collection)
    _elementsRead = 0
    _memoryBytesSpilled += currentMemory
    releaseMemory()
  }
  shouldSpill
}
Every 32 records read, if the collection's estimated size has reached myMemoryThreshold, it asks the execution memory pool for enough memory to double the collection (2 * currentMemory - myMemoryThreshold). If the grant is too small to cover current usage, the collection is spilled to disk.
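To make the arithmetic concrete, here is a tiny standalone simulation of that growth check (hypothetical numbers, not Spark code):

object SpillDecision {
  // Mirrors the post-grant check in maybeSpill: after requesting 2 * current - threshold,
  // spill if the (possibly unchanged) threshold still doesn't cover current usage.
  def shouldSpill(currentMemory: Long, threshold: Long, granted: Long): Boolean =
    currentMemory >= threshold + granted

  def main(args: Array[String]): Unit = {
    val mb = 1L << 20
    // 6 MB buffered, 5 MB threshold: we request 2*6-5 = 7 MB more.
    println(shouldSpill(6 * mb, 5 * mb, granted = 7 * mb)) // false -> keep buffering
    println(shouldSpill(6 * mb, 5 * mb, granted = 0))      // true  -> spill to disk
  }
}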
spill(collection) performs the actual spill to disk; we won't go into its details here.
Back in SortShuffleWriter#write:
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
This merges the in-memory data and any spill files into a single data file.
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
Next the index file is written. Its job is to record each partition's offset within the single shuffle data file, so that reducers can fetch exactly their slice of the data.
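Conceptually the index file is just the cumulative sums of partitionLengths. A minimal sketch of that bookkeeping (not the actual writeIndexFileAndCommit code):

val partitionLengths = Array(100L, 0L, 250L)        // example: 3 reduce partitions
val offsets = partitionLengths.scanLeft(0L)(_ + _)  // Array(0, 100, 100, 350)
// Reducer i fetches bytes [offsets(i), offsets(i + 1)) of the single data file;
// an empty partition simply has a zero-length range.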
Finally a MapStatus is returned.
The diagram below summarizes the SortShuffleWriter flow.
That was a long walk through ShuffleMapTask#runTask. ResultTask's runTask is much simpler:
val deserializeStartTime = System.currentTimeMillis()
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
  ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
func(context, rdd.iterator(partition, context))
It simply deserializes the RDD and the function, then applies the function to the partition's iterator. Since a ResultTask's output isn't consumed by any downstream task, it is written directly to disk or HDFS, or returned to the driver.
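For instance, RDD#collect rides exactly this path: its func just materializes the partition iterator, and the driver concatenates the per-partition arrays (simplified from the Spark source):

def collect(): Array[T] = withScope {
  // Each ResultTask runs iter.toArray on its partition and ships the array back to the driver
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}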