# Spark Source Code Analysis: Action Operations and the runJob Flow
## Current Environment and Versions

| Environment | Version |
|---|---|
| JDK | java version "1.8.0_231" (HotSpot) |
| Scala | Scala-2.11.12 |
| Spark | spark-2.4.4 |
## Preface
- In the earlier post on the SparkSubmit submission flow, we already walked through how a Spark application is submitted: registering the Application, starting the Driver, starting the Executors, and finally the Driver invoking the user class's main method via reflection.
- Then, in the post on SparkContext instantiation, we took a broad look at how the SparkContext in the user code is instantiated.
- In this post, we focus on what happens after the SparkContext has been instantiated as the user code continues to run, i.e. how an Action operation submits a job (runJob).
- Beyond that, you may also want to look at a simple implementation of a container with chainable, lazy semantics (a minimal sketch follows this list).
- I have drawn a diagram of the flow triggered by an Action (collect), as shown below.
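To make the "chainable, lazy" idea concrete before diving into Spark itself, here is a minimal, hypothetical sketch (not Spark code): each transformation only records a function over an iterator, and nothing is evaluated until a terminal operation such as `collect()` is called.

```scala
// Minimal, hypothetical sketch of a chainable, lazy container (not Spark code).
// map/filter only compose functions over an iterator; evaluation happens in collect().
class LazySeq[T](source: () => Iterator[T]) {
  def map[U](f: T => U): LazySeq[U] = new LazySeq(() => source().map(f))
  def filter(p: T => Boolean): LazySeq[T] = new LazySeq(() => source().filter(p))
  def collect(): List[T] = source().toList // the "Action" that actually triggers evaluation
}

object LazySeq {
  def apply[T](data: Seq[T]): LazySeq[T] = new LazySeq(() => data.iterator)
}

// Usage: the map/filter calls below do no work until collect() is invoked.
// LazySeq(Seq(1, 2, 3)).map(_ * 2).filter(_ > 2).collect()  // List(4, 6)
```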
## Code Used for the Analysis
- Let's start with a simple piece of example code (a word count).
```scala
object MySparkApp {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("MySparkApp")
    val spark = SparkSession.builder()
      .config(conf)
      .getOrCreate()
    // Data source (csv, comma separated)
    val srcRDD = spark.sparkContext.textFile("/test.csv")
    // Word count
    val resultRDD = srcRDD.flatMap(_.split(","))
      .map((_, 1))
      .reduceByKey(_ + _)
    // Trigger the Action
    val resultArr = resultRDD.collect()
    // Other operations omitted
    // ...
    spark.stop()
  }
}
```
- The code as a whole is fairly simple: build a SparkSession, read the source data, run a word count (which involves a Shuffle), and finally use `collect` to gather the results back to the Driver.
- What we mainly need to look at is the code path triggered by calling `collect` as the Action.
## Source Code Analysis of collect

`org.apache.spark.rdd.RDD`, `org.apache.spark.SparkContext`
- Hold `Ctrl` and left-click `collect` to jump into the RDD source (`org.apache.spark.rdd.RDD`); the code is as follows.

```scala
def collect(): Array[T] = withScope {
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}
```
- This part is straightforward: it calls `sc.runJob(...)` and returns an Array. Clicking through `runJob` a few more times, we arrive at lines 2077~2084 of SparkContext, shown below.

```scala
def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int]): Array[U] = {
  val results = new Array[U](partitions.size)
  runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res)
  results
}
```
- Note that the Array is handed into `runJob` through the function `(index, res) => results(index) = res`; later on this function assigns each partition's result into the array, which is how the results array can be returned. Clicking into `runJob` again, we arrive at lines 2047~2064 of SparkContext, shown below.

```scala
def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    resultHandler: (Int, U) => Unit): Unit = {
  if (stopped.get()) {
    throw new IllegalStateException("SparkContext has been shutdown")
  }
  val callSite = getCallSite
  // cleanedFunc ensures the closure is serializable, guarding against
  // non-serializable references captured inside func
  val cleanedFunc = clean(func)
  logInfo("Starting job: " + callSite.shortForm)
  if (conf.getBoolean("spark.logLineage", false)) {
    logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
  }
  // Hand the job over to the DAGScheduler
  dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
  // Finish the progress bar shown when the job was submitted from the command line
  progressBar.foreach(_.finishAll())
  rdd.doCheckpoint()
}
```
- The most important part of this code is the call to `dagScheduler.runJob(...)`, which hands the job over to the DAGScheduler. Also keep an eye on `resultHandler`: this function keeps getting passed further down the call chain (a small sketch of this callback pattern follows).
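To make the role of `resultHandler` concrete, here is a minimal, self-contained sketch (hypothetical, not Spark code) of the same pattern: the caller pre-allocates the results array and passes a callback that writes each partition's result into its slot.

```scala
// Hypothetical sketch of the resultHandler pattern used by runJob:
// the caller owns the results array; the callee only invokes the callback.
object ResultHandlerDemo {
  // Pretend to "run" one task per partition and report each result via the handler.
  def runAllPartitions[U](partitions: Seq[Int], compute: Int => U,
                          resultHandler: (Int, U) => Unit): Unit = {
    partitions.foreach(p => resultHandler(p, compute(p)))
  }

  def main(args: Array[String]): Unit = {
    val partitions = 0 until 4
    val results = new Array[Int](partitions.size)
    // The handler closes over `results`, just like (index, res) => results(index) = res above
    runAllPartitions(partitions, p => p * 10, (index: Int, res: Int) => results(index) = res)
    println(results.mkString(", ")) // 0, 10, 20, 30
  }
}
```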
## Processing in the DAGScheduler

`org.apache.spark.scheduler.DAGScheduler`
- Next, let's look at the `runJob` method of the DAGScheduler; the code is as follows.

```scala
def runJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): Unit = {
  val start = System.nanoTime
  // The key call
  val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
  // Block until the job completes
  ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf)
  waiter.completionFuture.value.get match {
    case scala.util.Success(_) =>
      // Success
      logInfo("Job %d finished: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
    case scala.util.Failure(exception) =>
      // Failure
      logInfo("Job %d failed: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
      val callerStackTrace = Thread.currentThread().getStackTrace.tail
      exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
      throw exception
  }
}
```
- Here `submitJob` builds a JobWaiter, and the method then blocks on it. Moving on to the code of `submitJob(...)`, shown below.

```scala
def submitJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): JobWaiter[U] = {
  // Check the partitions to make sure we don't launch a task on a non-existent partition
  val maxPartitions = rdd.partitions.length
  partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
    throw new IllegalArgumentException(
      "Attempting to access a non-existent partition: " + p + ". " +
        "Total number of partitions: " + maxPartitions)
  }

  // Job id
  val jobId = nextJobId.getAndIncrement()
  if (partitions.size == 0) {
    // Zero partitions means there is nothing to run
    return new JobWaiter[U](this, jobId, 0, resultHandler)
  }

  assert(partitions.size > 0)

  // The key part: wrap everything into a JobSubmitted event and post it to the queue
  val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
  val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
  eventProcessLoop.post(JobSubmitted(
    jobId, rdd, func2, partitions.toArray, callSite, waiter,
    SerializationUtils.clone(properties)))
  waiter
}
```
- The key part of this code is the last few lines: a `JobSubmitted` event is built and posted to a queue via `eventProcessLoop`.
- If you are interested, take a look at this `EventLoop` (the concrete class here is `DAGSchedulerEventProcessLoop`). Internally it holds a LinkedBlockingDeque; once started, a daemon thread keeps polling the queue, takes out each event, and calls `onReceive` to handle it. Here that ultimately lands in the DAGScheduler's `doOnReceive` method, which matches the `JobSubmitted` event (see the event-loop sketch after the code below).

```scala
private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
  case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
    dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)
  // ... other cases omitted ...
}
```
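As a rough illustration of how such an event loop works, here is a minimal, hypothetical sketch (not the actual `EventLoop` in Spark): a blocking queue plus a daemon thread that dispatches every dequeued event to `onReceive`.

```scala
import java.util.concurrent.LinkedBlockingDeque

// Hypothetical sketch, similar in spirit to Spark's EventLoop:
// post() enqueues events; a daemon thread dequeues them and calls onReceive.
abstract class SimpleEventLoop[E](name: String) {
  private val eventQueue = new LinkedBlockingDeque[E]()
  @volatile private var stopped = false

  private val eventThread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit = {
      try {
        while (!stopped) {
          onReceive(eventQueue.take()) // take() blocks until an event arrives
        }
      } catch {
        case _: InterruptedException => // stop() interrupts a blocked take()
      }
    }
  }

  def start(): Unit = eventThread.start()
  def stop(): Unit = { stopped = true; eventThread.interrupt() }
  def post(event: E): Unit = eventQueue.put(event)

  protected def onReceive(event: E): Unit
}
```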
- Alternatively, you can use the quick-lookup trick mentioned earlier in the post on the Master/Worker startup flow to quickly find where `JobSubmitted` is received.
- Next, let's look at the DAGScheduler's `handleJobSubmitted(...)` method, shown below.

```scala
private[scheduler] def handleJobSubmitted(jobId: Int,
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    callSite: CallSite,
    listener: JobListener,
    properties: Properties) {
  var finalStage: ResultStage = null
  try {
    // Split the job into Stages along ShuffleDependency boundaries.
    // What is returned here is the last Stage, i.e. the ResultStage.
    finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
  } catch {
    // Exception handling omitted...
  }
  // Job submitted, clear internal data.
  barrierJobIdToNumTasksCheckFailures.remove(jobId)

  // Wrap the job-related information
  val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
  clearCacheLocs()
  // Some code omitted...
  val jobSubmissionTime = clock.getTimeMillis()
  jobIdToActiveJob(jobId) = job
  activeJobs += job
  finalStage.setActiveJob(job)
  val stageIds = jobIdToStageIds(jobId).toArray
  val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
  // Post to the listenerBus so the job submission shows up in the UI
  listenerBus.post(
    SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
  // Finally, submit the Stage
  submitStage(finalStage)
}
```
- There are two key points here:
  - `createResultStage(...)` splits the job into Stages (ShuffleMapStage, ResultStage) based on wide/narrow dependencies. We will cover that code in a later post; you can also read it yourself (the recursion takes some effort to follow). See the illustration below for how the example job splits into stages.
  - `submitStage(...)` is called to submit the Stage.
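For the word-count example at the top of this post, `reduceByKey` introduces a ShuffleDependency, so `createResultStage` ends up with a ShuffleMapStage (textFile/flatMap/map) followed by the final ResultStage. You can see the lineage and the shuffle boundary yourself with `toDebugString` (the output below only illustrates its shape, it is not an exact transcript).

```scala
// Print the RDD lineage of the example job; the indentation change marks
// the shuffle boundary introduced by reduceByKey.
println(resultRDD.toDebugString)
// Illustrative output shape:
// (N) ShuffledRDD[4] at reduceByKey ...
//  +-(N) MapPartitionsRDD[3] at map ...
//     |  MapPartitionsRDD[2] at flatMap ...
//     |  /test.csv MapPartitionsRDD[1] at textFile ...
//     |  /test.csv HadoopRDD[0] at textFile ...
```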
- Now let's look at the `submitStage(...)` method.

```scala
private def submitStage(stage: Stage) {
  val jobId = activeJobForStage(stage)
  if (jobId.isDefined) {
    logDebug("submitStage(" + stage + ")")
    // The Stage must not be waiting, running, or failed
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      // Look upwards for parent Stages, because what was submitted first is the last Stage
      val missing = getMissingParentStages(stage).sortBy(_.id)
      logDebug("missing: " + missing)
      if (missing.isEmpty) {
        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
        // No parent Stages left above, so actually submit this Stage
        submitMissingTasks(stage, jobId.get)
      } else {
        for (parent <- missing) {
          // Recurse
          submitStage(parent)
        }
        waitingStages += stage
      }
    }
  } else {
    abortStage(stage, "No active job for stage " + stage.id, None)
  }
}
```
- Because the Stage returned by the initial Stage analysis is the last Stage, the code has to recurse upwards to find the earliest Stages and submit those first.
- Now let's see how the DAGScheduler submits the tasks. The `submitMissingTasks(...)` method is rather long (lines 1083~1232), so we will only look at its key parts.
- Call `getPreferredLocs` to compute the best locations for the Tasks (lines 1105~1124):

```scala
// Use getPreferredLocs to get the optimal locations for processing
val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
  stage match {
    case s: ShuffleMapStage =>
      // If it is a ShuffleMapStage
      partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id)) }.toMap
    case s: ResultStage =>
      // If it is a ResultStage
      partitionsToCompute.map { id =>
        val p = s.partitions(id)
        (id, getPreferredLocs(stage.rdd, p))
      }.toMap
  }
} catch {
  // Code omitted
}
stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
```
- Broadcast the serialized task binary (lines 1140~1176):

```scala
var taskBinary: Broadcast[Array[Byte]] = null
var partitions: Array[Partition] = null
try {
  var taskBinaryBytes: Array[Byte] = null
  RDDCheckpointData.synchronized {
    taskBinaryBytes = stage match {
      case stage: ShuffleMapStage =>
        // rdd, shuffleDep
        JavaUtils.bufferToArray(
          closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
      case stage: ResultStage =>
        // rdd, func
        JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
    }
    partitions = stage.rdd.partitions
  }
  taskBinary = sc.broadcast(taskBinaryBytes)
} catch {
  // Code omitted
}
```
- Serialize the task metrics and build the Tasks (lines 1178~1208):

```scala
val tasks: Seq[Task[_]] = try {
  val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
  stage match {
    case stage: ShuffleMapStage =>
      // Build ShuffleMapTasks, carrying the previously broadcast taskBinary
      stage.pendingPartitions.clear()
      partitionsToCompute.map { id =>
        val locs = taskIdToLocations(id)
        val part = partitions(id)
        stage.pendingPartitions += id
        new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
          taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
          Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
      }

    case stage: ResultStage =>
      // Build ResultTasks, carrying the previously broadcast taskBinary
      partitionsToCompute.map { id =>
        val p: Int = stage.partitions(id)
        val part = partitions(p)
        val locs = taskIdToLocations(id)
        new ResultTask(stage.id, stage.latestInfo.attemptNumber,
          taskBinary, part, locs, id, properties, serializedTaskMetrics,
          Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
          stage.rdd.isBarrier())
      }
  }
} catch {
  // Code omitted
}
```
- Wrap the Tasks into a TaskSet and submit them via the TaskScheduler (lines 1210~1231):

```scala
// Greater than 0 means there are tasks to run
if (tasks.size > 0) {
  logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
    s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
  // Submit the tasks via the TaskScheduler
  taskScheduler.submitTasks(new TaskSet(
    tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties))
} else {
  // Handling for the case where there are no Tasks
  // Some code omitted
}
```
## Processing in the TaskScheduler

`org.apache.spark.scheduler.TaskSchedulerImpl`
- Let's look at the processing in the TaskScheduler; the `submitTasks(...)` method is shown below.

```scala
override def submitTasks(taskSet: TaskSet) {
  val tasks = taskSet.tasks
  logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
  this.synchronized {
    // Convert the TaskSet into a TaskSetManager
    val manager = createTaskSetManager(taskSet, maxTaskFailures)
    val stage = taskSet.stageId
    val stageTaskSets =
      taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
    stageTaskSets.foreach { case (_, ts) => ts.isZombie = true }
    stageTaskSets(taskSet.stageAttemptId) = manager
    // Add the TaskSetManager to the scheduling Pool (either FIFO or Fair)
    schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

    // Start a timer that checks whether the tasks have actually started running
    if (!isLocal && !hasReceivedTask) {
      starvationTimer.scheduleAtFixedRate(new TimerTask() {
        override def run() {
          if (!hasLaunchedTask) {
            // This is the log line you often see after submitting a job
            // when the cluster does not have enough resources
            logWarning("Initial job has not accepted any resources; " +
              "check your cluster UI to ensure that workers are registered " +
              "and have sufficient resources")
          } else {
            this.cancel()
          }
        }
      }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
    }
    hasReceivedTask = true
  }
  // Important: use the SchedulerBackend to send a message to the Driver
  backend.reviveOffers()
}
```
- This part of the code does two main things:
  - Convert the TaskSet into a TaskSetManager and add it to the scheduling Pool.
  - Use the SchedulerBackend to send a message to the Driver so that the tasks in the Pool get processed.
- SchedulerBackend has several implementations; in the following we use CoarseGrainedSchedulerBackend as the example.
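As a side note (not part of the call chain above): which scheduling mode the Pool uses is controlled by the `spark.scheduler.mode` configuration, which accepts FIFO (the default) or FAIR. For example:

```scala
// Choosing the scheduling mode when building the SparkConf.
// FIFO is the default; FAIR additionally supports pools defined in fairscheduler.xml.
val conf = new SparkConf()
  .setAppName("MySparkApp")
  .set("spark.scheduler.mode", "FAIR")
```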
## Processing in CoarseGrainedSchedulerBackend / DriverEndpoint

`org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.DriverEndpoint`
- After CoarseGrainedSchedulerBackend calls `reviveOffers()`, it uses the driverEndpoint Ref to send a `ReviveOffers` message to the Driver.
- If you are not yet familiar with RpcEndpoint and RpcEnv, you can use the quick-lookup trick mentioned in the post on the Master/Worker startup flow to quickly locate the `receive` method of DriverEndpoint (which is actually an inner class of CoarseGrainedSchedulerBackend); the code is as follows.

```scala
override def receive: PartialFunction[Any, Unit] = {
  // Code omitted...
  case ReviveOffers =>
    makeOffers()
  // Code omitted...
}
```
- Next, let's look at the `makeOffers()` method.

```scala
private def makeOffers() {
  // Make sure the Executors that are about to run Tasks are healthy
  val taskDescs = withLock {
    val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
    val workOffers = activeExecutors.map {
      case (id, executorData) =>
        new WorkerOffer(id, executorData.executorHost, executorData.freeCores,
          Some(executorData.executorAddress.hostPort))
    }.toIndexedSeq
    scheduler.resourceOffers(workOffers)
  }
  // OK, everything checks out, so launch the Tasks
  if (!taskDescs.isEmpty) {
    launchTasks(taskDescs)
  }
}
```
- This part should be fairly self-explanatory; let's continue with the `launchTasks(...)` method.

```scala
private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
  // Handle each Task in turn
  for (task <- tasks.flatten) {
    val serializedTask = TaskDescription.encode(task)
    if (serializedTask.limit() >= maxRpcMessageSize) {
      // If the serialized task exceeds the maximum RPC message size, report it
      // Code omitted
    } else {
      // Update the executor's bookkeeping
      val executorData = executorDataMap(task.executorId)
      executorData.freeCores -= scheduler.CPUS_PER_TASK

      logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
        s"${executorData.executorHost}.")

      // Wrap the serialized Task into a LaunchTask message
      // and send it to the Executor to launch the task
      executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
    }
  }
}
```
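The `maxRpcMessageSize` checked above comes from the `spark.rpc.message.maxSize` configuration (value in MiB, default 128); if a serialized task hits the limit, the omitted branch rejects it and the message it produces suggests increasing this setting. Raising it would look like this (an illustrative value, not something this example job needs):

```scala
// Raise the RPC message size limit to 256 MiB (default is 128).
val conf = new SparkConf()
  .set("spark.rpc.message.maxSize", "256")
```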
## Processing in the Executor

`org.apache.spark.executor.CoarseGrainedExecutorBackend`
- Using the quick-lookup trick mentioned several times above, we can locate the `receive` method of CoarseGrainedExecutorBackend on the Executor side.

```scala
override def receive: PartialFunction[Any, Unit] = {
  // Code omitted...
  case LaunchTask(data) =>
    if (executor == null) {
      exitExecutor(1, "Received LaunchTask command but executor was null")
    } else {
      val taskDesc = TaskDescription.decode(data.value)
      logInfo("Got assigned task " + taskDesc.taskId)
      executor.launchTask(this, taskDesc)
    }
  // Code omitted...
}
```
- In the end, the Executor's `launchTask` method is called to handle the Task.

```scala
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  // Wrap the task description in a TaskRunner
  val tr = new TaskRunner(context, taskDescription)
  runningTasks.put(taskDescription.taskId, tr)
  // Submit it to the thread pool
  threadPool.execute(tr)
}
```
- If you want to keep going, take a look at the TaskRunner's `run` method next ^_^