关于任务切分与运行可以参考:spark源码解析之三、任务切分与运行
spark运行流程分为资源环境准备和任务提交运行两个步骤,两个步骤交叉进行,当前以任务提交为主线进行源码分析。
资源环境准备线,可以参考spark源码解析之二、计算资源准备
关于源代码的前期准备可以参考:spark源码解析之一、整体概述
一、shuffle概述
Shuffle 机制是 Spark Core 的核心内容。在 Stage 和 Stage 之间,Spark 需要 Shuffle 数据。这个流程包含上一个 Stage 上的 Shuffle Write,中间的数据传输,以及下一个 Stage 的 Shuffle Read。如下图所示:
强调一点,shuffle分两个过程:shuffle读和shuffle写,shuffle读发生在一个stage的开始,shuffle写发生在ShuffleMapStage的结尾,shuffle读在前,shuffle写在后,且分布在一个stage的首尾,在后续的原码中也会有所体现。
在前几篇文章中分析了spark资源准备和stage切分、task提交的源代码,那么我们接下来分析任务task运行过程中必不可少的shuffle过程的原码。
task任务提交之后,在stage的最后阶段就是shuffle数据落盘的过程,数据落盘完成则标着者下一个stage的开始,下一个stage的数据来源就是上一个stage的shuffle数据文件。
spark中的task分为两类ResultTask和ShuffleMapTask,ResultTask作为最终阶段的task,写的过程主要是一些行动算子,不同的行动算子具有不同的逻辑,不具有代表性。而ShuffleMapTask具备了shuffle的通用功能读和写,所以主要从ShuffleMapTask的runtask开始。
shuffle从代码中可以看到,我们暂时先关注三件事情,ShuffleWriter的创建、Shuffle读以及shuffle写。
1.1 ShuffleMapTask.runTask
override def runTask(context: TaskContext): MapStatus = {
....
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
//一:ShuffleWriter的创建
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
//二:Shuffle写主要是rdd.iterator
//三:shuffle读主要是writer.write
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
} catch {
....
}
}
二、ShuffleWriter的创建
注册ShuffleHandle,不同的ShuffleHandle用于创建不同的ShuffleWriter。从这个判断方法中可以看到,其实spark框架是优先判断是否符合bypass机制,如果不符合在判断是否是序列化shuffle机制,如果两者都不符合才是baseshuffle机制。其实这个过程就像去某地,地图肯定会规划多条路径,最终给定一条最优解,也从侧面说明这几种机制的顺序实现起来是越来越麻烦。
2.1 SortShuffleManager.registerShuffle
override def registerShuffle[K, V, C](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
//判断是否使用忽略合并排序
if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
// need map-side aggregation, then write numPartitions files directly and just concatenate
// them at the end. This avoids doing serialization and deserialization twice to merge
// together the spilled files, which would happen with the normal code path. The downside is
// having multiple files open at a time and thus more memory allocated to buffers.
new BypassMergeSortShuffleHandle[K, V](
shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
// Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
new SerializedShuffleHandle[K, V](
shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
// Otherwise, buffer map outputs in a deserialized form:
new BaseShuffleHandle(shuffleId, numMaps, dependency)
}
}
首先关注一下判断bypass机制的逻辑。
2.2 SortShuffleManager.shouldBypassMergeSort
def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
//RDD依赖中已经明确定义map端预聚合
if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
false
} else {
val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
//task分区数量小于200,bypassMergeThreshold默认值为200,可以调整
dep.partitioner.numPartitions <= bypassMergeThreshold
}
}
从源码逻辑可以发现bypass机制需要符合一些条件,shuffle算子不能支持map端聚合且stage中task数量不高于200的阈值。这个200的阈值可以通过spark.shuffle.sort.bypassMergeThreshold动态配置,在生产环境中如果运行环境资源允许,可以调大该阈值的配置,以便命中bypass的可能性,从而提升task运行效率。
2.3 SortShuffleWriter.getWriter
//不同handle创建不同shuffleWriter
override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Int,
context: TaskContext): ShuffleWriter[K, V] = {
numMapsForShuffle.putIfAbsent(
handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
val env = SparkEnv.get
handle match {
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
context,
env.conf)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
bypassMergeSortHandle,
mapId,
context,
env.conf)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
}
}
根据不同的ShuffleHandle创建shuffleWrite。
三、Shuffle读
我们从前边的stage划分可以知道,stage划分的分界线是款依赖算子,一个stage的task数量取决于款依赖的分区数量,也就是说临界算子有多少个分区就会划分多少个task,其实每个task就是这一stage从前到后的算子封装,每个task封装的逻辑一样,只是读取处理的数据不同而已,所以,shuffle读就是正常算子的读,没什么特别,只是读取的是上一个stage的shuffle数据而已。
我们从writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])入手,最终追踪到rdd.compute方法,这个compute方法有多种实现,但是只有ShuffledRDD这一类的RDD才会有shuffle的读,所以就从ShuffledRDD.compute方法看起。
3.1 ShuffledRDD.compute
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
//注意这里特别关注一下getReader的参数
//dep.shuffleHandle用于获取前一阶段shuffle数据文件元数据
//split.index 上一阶段shuffle结果数据切片索引开始
//split.index + 1 上一阶段shuffle结果数据切片索引结束
//从这里可以看出一个reader只读取上一阶段shuffle结果数据的一个分区,这个shuffle溢写数据文件逻辑保持一致
//即上一阶段shufflewrite会将shuffle结果数据根据下一阶段分区数也是task数量进行落盘,一个task一份分区"一段数据"
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[(K, C)]]
}
3.2 BlockStoreShuffleReader.read
override def read(): Iterator[Product2[K, C]] = {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
//注意这里第一个参数是从handle中获取的shuffleId,即上一个阶段的shuffle结果数据的标记
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
// 设置每次拉取的数据大小,默认48M
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
//设置每次拉取的数据量最大值,默认为Int最大值
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
// 根据配置的压缩和解码方式包装流
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
serializerManager.wrapStream(blockId, inputStream)
}
val serializerInstance = dep.serializer.newInstance()
// 为读取到的每个流创建KV迭代器
val recordIter = wrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
//读取数据后,更新读取量,后续用于记录跟踪和评估
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
//通过map操作每条记录,其实什么都没做,只是记录了读取的数据量
recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
},
context.taskMetrics().mergeShuffleReadMetrics())
// 设置可中断迭代器,以便取消task
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
//需要对读取的数据在map端进行聚合,比如reduceByKey会在map进行预聚合
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
//按照key对数据集进行聚合
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
//不关注预聚合,只关注value,将相同key的values进行归集,比如groupByKey
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
//未指定聚合器的映map端合并,直接什么都不做,数据保持原样,比如sortByKey对value没任何要求
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// 如果指定排序器,则对数据集进行排序输出
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// 创建排序器ExternalSorter对结果集进行排序
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
//如果需要排序,则需要使用排序器对数据集进行排序
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
//不需要排序,数据保持原样
aggregatedIter
}
}
3.3 ExternalSorter.insertAll
参考4.1.1以及后续方法。
四、Shuffle写之SortShuffleWriter
SortShuffleWriter的写方法主要做了三件事情
1.对所有记录写入内存或者文件;
2.对数据溢写并合不同分区数据文件;
3.对数据文件创建索引文件。
override def write(records: Iterator[Product2[K, V]]): Unit = {
//注意这个sorter决定了后边溢写的方式
sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
//map端预聚合需要对key进行排序 dep.keyOrdering
new ExternalSorter[K, V, C](
context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
//无需map端预聚合,不需要对key进行排序 ordering = None
new ExternalSorter[K, V, V](
context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
//对所有记录进行排序
sorter.insertAll(records)
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
//创建临时文件
val tmp = Utils.tempFileWith(output)
try {
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
}
}
}
根据write代码逻辑,首先会根据shuffle中的算子是否需要map端预聚合,生成不同参数的排序器,这个排序器决定了后续数据的排序和数据文件的溢写。接下来就使用该排序器对所有数据记录进行归集操作,要么写入内存,如果内存不足以存储,则溢写到磁盘,注意这里的溢写,只要内存满就会写到文件,所以会产生很多的小文件,这一点在spill方法中会有所体现。
4.1 对所有记录写入内存或者文件
4.1.1 ExternalSorter.insertAll
def insertAll(records: Iterator[Product2[K, V]]): Unit = {
val shouldCombine = aggregator.isDefined
//是否需要预聚合
if (shouldCombine) {
//map端预聚合使用AppendOnlyMap数据结构
val mergeValue = aggregator.get.mergeValue
val createCombiner = aggregator.get.createCombiner
var kv: Product2[K, V] = null
//定义聚合函数,通过map数据结构对相同key数据进行聚合
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
//从这里可以看到,其实spark处理数据最终也是对数据逐条进行处理,没什么特殊之处
while (records.hasNext) {
//记录当前处理的记录数,该方法会进行+1操作
addElementsRead()
kv = records.next()
//由于map端预聚合,需要根据key值,对map中数据进行聚合更新
map.changeValue((getPartition(kv._1), kv._1), update)
//判断是否需要溢写操作,usingMap = true,因为只有map结构才能很好低支持预聚合
maybeSpillCollection(usingMap = true)
}
} else {
//如果不需要预聚合,直接插入即可
while (records.hasNext) {
//记录当前处理的记录数
addElementsRead()
val kv = records.next()
//直接插入buffer内存
buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
//判断是否需要溢写操作,usingMap = false,这里直接使用缓冲区buffer
maybeSpillCollection(usingMap = false)
}
}
}
这里一定要注意,不管是哪种方式都会有是否需要溢写的判断,并且判断是在while循环中,也就是说,每条记录进来都会走一遍后续的流程。
4.1.2 ExternalSorter.maybeSpillCollection
private def maybeSpillCollection(usingMap: Boolean): Unit = {
var estimatedSize = 0L
if (usingMap) {
estimatedSize = map.estimateSize()
//判断是否需要spill即溢写文件,注意这里是每来一条记录都会进行一次判断
if (maybeSpill(map, estimatedSize)) {
//如果溢写成功map结构初始化
map = new PartitionedAppendOnlyMap[K, C]
}
} else {
estimatedSize = buffer.estimateSize()
//判断是否需要spill即溢写文件,注意这里是每来一条记录都会进行一次判断
if (maybeSpill(buffer, estimatedSize)) {
//如果溢写成功buffer缓冲区初始化
buffer = new PartitionedPairBuffer[K, C]
}
}
if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
}
不管是使用map还是使用buffer封装数据集,最终都要判断数据集是否达到溢写阈值,注意这里一直在内存中,并没有文件什么事情,且对数据集中每条数据都会进行一次判断,调用这个方法是在一个while循环中,可以参考3.1.1。
4.1.3 Spillable
@volatile private[this] var myMemoryThreshold = initialMemoryThreshold
//默认内存缓冲区大小为5M,可以通过spark.shuffle.spill.initialMemoryThreshold动态配置
private[this] val initialMemoryThreshold: Long =
SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
//默认强制溢写数量为Long最大值,默认值很大,没有优化的必要
private[this] val numElementsForceSpillThreshold: Long =
SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue)
......
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
var shouldSpill = false
//写入数量是否是32的倍数,且当前内存是仍然允许写入
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
//从shuffle内存池中获取最多两倍的当前内存
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted = acquireMemory(amountToRequest)
myMemoryThreshold += granted
//进来一条数据后,如果内存数量大于内存阈值则溢写文件
shouldSpill = currentMemory >= myMemoryThreshold
}
//这里是一个强制性判断,如果写入数量大于强制溢出阈值
shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
if (shouldSpill) {
_spillCount += 1
logSpillage(currentMemory)
spill(collection)
//溢写之后数量清零
_elementsRead = 0
_memoryBytesSpilled += currentMemory
//释放内存
releaseMemory()
}
shouldSpill
}
每进来一条数据都要判断一下是否符合溢写的条件,如果需要溢写,就会进行溢写操作。
4.1.4 ExternalSorter.spill
//spills是一个数据集:数组,存放的是溢写临时文件
private val spills = new ArrayBuffer[SpilledFile]
.....
override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
//注意溢写一次就会写入一个文件
spills += spillFile
}
......
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
: SpilledFile = {
// 因为这些文件可能在shuffle过程中被读取,所以它们的压缩必须使用spark.shuffle.compress压缩方式,
// 而不是shuffle溢写的压缩方式,因此我们需要在这里使用createTempShuffleBlock;
val (blockId, file) = diskBlockManager.createTempShuffleBlock()
// 溢写文件之后需要更新变量
var objectsWritten: Long = 0
val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
val writer: DiskBlockObjectWriter =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
//批次大小
val batchSizes = new ArrayBuffer[Long]
//每个分区拥有元素数量列表
val elementsPerPartition = new Array[Long](numPartitions)
//刷写之后,更新相关变量
//这里只是定义了这个方法,被调用才会真正执行
def flush(): Unit = {
val segment = writer.commitAndGet()
batchSizes += segment.length
_diskBytesSpilled += segment.length
objectsWritten = 0
}
var success = false
try {
while (inMemoryIterator.hasNext) {
val partitionId = inMemoryIterator.nextPartition()
require(partitionId >= 0 && partitionId < numPartitions,
s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
inMemoryIterator.writeNext(writer)
elementsPerPartition(partitionId) += 1
objectsWritten += 1
//写入累计数量已经达到批次阈值,则进行刷写,这个阈值是一个调优的对象
if (objectsWritten == serializerBatchSize) {
flush()
}
}
if (objectsWritten > 0) {
flush()
} else {
writer.revertPartialWritesAndClose()
}
success = true
} finally {
if (success) {
//关闭流资源
writer.close()
} else {
writer.revertPartialWritesAndClose()
if (file.exists()) {
if (!file.delete()) {
logWarning(s"Error deleting ${file}")
}
}
}
}
SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
}
4.2 对数据溢写并合不同分区数据文件
4.2.1 ExternalSorter.writePartitionedFile
def writePartitionedFile(
blockId: BlockId,
outputFile: File): Array[Long] = {
//跟踪输出文件中每个范围的位置
val lengths = new Array[Long](numPartitions)
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics().shuffleWriteMetrics)
if (spills.isEmpty) {
//没有溢写,说明只需要处理内存数据
val collection = if (aggregator.isDefined) map else buffer
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
val partitionId = it.nextPartition()
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(writer)
}
//将内存中数据写到文件,注意这里不是溢写
val segment = writer.commitAndGet()
lengths(partitionId) = segment.length
}
} else {
//按分区获取迭代器并直接写入到数据文件,在写入时使用同一个writer,说明写入了一个文件
//从这里我们可以得出结论:一个task只写了一个数据文件
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
for (elem <- elements) {
writer.write(elem._1, elem._2)
}
val segment = writer.commitAndGet()
lengths(id) = segment.length
}
}
}
writer.close()
context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
lengths
}
4.2.2 groupByPartition
//从处理逻辑来看这里返回的数据结构的key根分区号相关
private def groupByPartition(data: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] =
{
val buffered = data.buffered
(0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
}
4.2.3 merge
// 合并已排序的文件,最终写一个新的文件或返回数据,这里的merge的是溢写的临时数据文件
private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] = {
val readers = spills.map(new SpillReader(_))
val inMemBuffered = inMemory.buffered
//这里一定要注意,这里遍历的是分区数量,返回结果key就是分区索引位
(0 until numPartitions).iterator.map { p =>
val inMemIterator = new IteratorForPartition(p, inMemBuffered)
//这里读取文件并合并分区内数据,将数据归集起来,并没有写入
val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
if (aggregator.isDefined) {
//跨分区执行部分聚合,主要用于处理给定的比较器。默认情况下,同一个key会放到同一个分区中,那么
//如果使用自定义的的排序器,不同的key可能相等,那就需要跨分区操作。
(p, mergeWithAggregation(
iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
} else if (ordering.isDefined) {
//不定义排序规则但是需要进行排序,由于同一个分区会有多个临时数据文件,所以需要在多个文件之间进行合并并排序
(p, mergeSort(iterators, ordering.get))
} else {
//如果排序聚合什么都没有,则将数据压平
(p, iterators.iterator.flatten)
}
}
}
4.3 对数据文件创建索引文件
4.3.1 IndexShuffleBlockResolver.writeIndexFileAndCommit
//临时文件合并后,些数据文件的索引文件
def writeIndexFileAndCommit(
shuffleId: Int,
mapId: Int,
lengths: Array[Long],
dataTmp: File): Unit = {
val indexFile = getIndexFile(shuffleId, mapId)
val indexTmp = Utils.tempFileWith(indexFile)
try {
//维护索引临时文件
val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
Utils.tryWithSafeFinally {
//取每个数据块的长度,需要将其转换为偏移量然后维护索引临时文件
var offset = 0L
out.writeLong(offset)
//注意这里的lengths是分区的个数,也就是下一个阶段的task个数,从这里可以明确一个task只有一个索引文件
for (length <- lengths) {
offset += length
out.writeLong(offset)
}
} {
out.close()
}
val dataFile = getDataFile(shuffleId, mapId)
//每个执行器只有一个IndexShuffleBlockResolver,此同步确保以下检查和重命名是原子的
synchronized {
//判断索引文件是否已经存在
val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
if (existingLengths != null) {
//存在则说明同一任务的另一次尝试已经成功维护了索引文件,则需要将临时索引文件和临时
//数据文件删除即可
System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
if (dataTmp != null && dataTmp.exists()) {
dataTmp.delete()
}
indexTmp.delete()
} else {
//这是为该task的第一次成功尝试,直接使用现有的索引和数据文件
//索引文件存在,则删除
if (indexFile.exists()) {
indexFile.delete()
}
//数据文件则删除
if (dataFile.exists()) {
dataFile.delete()
}
//将索引临时文件重命名为索引文件
if (!indexTmp.renameTo(indexFile)) {
throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
}
//将数据文件删除,将现有的数据临时文件重命名为数据文件
if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
}
}
}
} finally {
if (indexTmp.exists() && !indexTmp.delete()) {
logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}")
}
}
}
五、Shuffle写之UnsafeShuffleWriter
未完待续。
六、Shuffle写之BypassMergeSortShuffleWriter
未完待续。
由于对spark理解有限,中间难免会有错误,还请各位指正,共同讨论学习。后续随着对spark理解的深入,会继续修改文章。