spark join shuffle 数据读取的过程
在spark中,当数据要shuffle时,这个拉取过程RDD是怎么和ShuffleMapTask 关联起来的。
在CoGroupedRDD通过调用如下函数去读取指定分区的数据
SparkEnv.get.shuffleManager
.getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
.read()
通过上面的方法,就可以知道调用那个依赖的RDD,读取那个分片数据。
然后创建BlockStoreShuffleReader读取对象。在该类中执行下面的方法
// 下面就是对这个shuffler中的分片数据进行读取并进行相关的aggregate操作了
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
可以看到首先要通过mapOutputTracker去拉取该分区的地址信息
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
// 拉取这些状态数据回来了
val statuses = getStatuses(shuffleId)
// Synchronize on the returned array because, on the driver, it gets mutated in place
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
}
}
然后在 getStatuses函数中,发起远程调用,读取这个shuffle的结果地址数据
try {
// 拉取这个shuffle的状态数据
val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
// 这个status是那些数据分片的地址来的
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
}
在MapOutputTrackerMaster中的MapOutputTrackerMasterEndpoint 接收线程中,接收到相关的消息
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
// 问这个shuffler的地址信息
val hostPort = context.senderAddress.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
// 去这个tracker里面去拉取了
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
val serializedSize = mapOutputStatuses.length
if (serializedSize > maxAkkaFrameSize) {