8.Executor源码分析与Task源码分析

9 篇文章 0 订阅
先来一张分析图:
CoarseGrainedExecutorBackend是在worker上启动的 , 是一个worker上的后台进程 , 启动之后会获取Driver的actor立即向Driver发送注册Executor的消息 , 注册成功之后Driver又会向CoarseGrainedExecutorBackend返回注册成功的RegisterExecutor消息 , 此时就会正式的创建Executor对象  , 等待Driver发送launchTask消息(TaskScheduler分配task完时就会发送这个消息) 

接收到launchTask消息后 , 由于Executor已经有了task , 就会对task继续反序列化 , 获取task里面的信息之后就会让Executor调用launchTask()方法 , 然后将task封装成taskRunner , 该runner就是实现了Java中的Runnable接口 , 将该对象放入一个ConcurrentHashMap集合中 , 随后调用Java的线程池对象(ThreadPoolExecutor)执行这个TaskRunner



从CoarseGrainedExecutorBackend的preStart()方法开始 , 源码如下:
    
    
  1. /**
  2. * 在初始化方法中
  3. */
  4. override def preStart() {
  5. logInfo("Connecting to driver: " + driverUrl)
  6. // 获取了driver的actor
  7. driver = context.actorSelection(driverUrl)
  8. // 向driver发送RegisterExecutor消息
  9. driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls)
  10. context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
  11. }
通过Driver的actor向driver发送一个需要注册Executor的消息 , 若发送成功Driver会返回一个可以注册的消息RegisteredExecutor . 源码如下:
     
     
  1. // driver注册executor成功之后会发送回来RegisteredExecutor消息
  2. // 此时CoarseGrainedExecutorBackend会创建Executor对象,作为执行句柄
  3. // 其实它的大部分功能都是executor实现的
  4. case RegisteredExecutor =>
  5. logInfo("Successfully registered with driver")
  6. val (hostname, _) = Utils.parseHostPort(hostPort)
  7. executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
上面的代码中就是创建一个executor了 , 接下来就是等待driver发送LaunchTask的消息 , 若是接收到LaunchTask的消息之后就会执行如下代码:
     
     
  1. // 启动task
  2. case LaunchTask(data) =>
  3. if (executor == null) {
  4. logError("Received LaunchTask command but executor was null")
  5. System.exit(1)
  6. } else {
  7. // 反序列化task
  8. val ser = env.closureSerializer.newInstance()
  9. val taskDesc = ser.deserialize[TaskDescription](data.value)
  10. logInfo("Got assigned task " + taskDesc.taskId)
  11. // 用内部的执行句柄executor的launchTask方法来启动一个task
  12. executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
  13. taskDesc.name, taskDesc.serializedTask)
  14. }
其中的executor的LaunchTask就是执行task任务的关键代码了  , 当然需要将接收到的task信息进行反序列化 , 然后就是在executor中执行task的源码了:
     
     
  1. def launchTask(
  2. context: ExecutorBackend,
  3. taskId: Long,
  4. attemptNumber: Int,
  5. taskName: String,
  6. serializedTask: ByteBuffer) {
  7. // 对于每一个task都会创建一个TaskRunner , TaskRunner继承的是Java的多线程的Runnable接口
  8. val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
  9. serializedTask)
  10. // 将TaskRunner放入内存缓存
  11. runningTasks.put(taskId, tr)
  12. // 通过Java的线程池ThreadPool执行taskRunner
  13. // Executor中有一个线程池,task被封装在TaskRunner中,直接将TaskRunner丢入线程池进行执行,然而线程池是自动实现了排队机制的
  14. // 因此若是线程池里面的任务都没有空闲的那么新丢进来的TaskRunner是等待的
  15. threadPool.execute(tr)
  16. }
从上面的代码中可以看出task信息是被封装在了一个叫做TaskRunner的对象中 , 而这个TaskRunner集成的就是Java中的Runnable接口 ,也就是线程 , 然后会将这个taskRunner放入Java中的线程池中进行运行 , 同时也会将这个TaskRunner放入缓存中 .
以上就是Executor的源码分析 , 比较简单 !

接下来就是Task的源码分析了 , 先上一张原理图:
 
既然Executor中用线程池对象执行一个TaskRunner , 那么TaskRunner中的run方法就是运行task的核心代码 , 而这个run方法中的代码有点多 , 我分为几个步骤来讲解:
1.通过网络加载相应的文件,资源,jar包等信息 , 源码如下:
     
     
  1. // 对序列化的task数据进行反序列化
  2. val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
  3. /// 然后通过网络通信将需要的文件,资源,jar包拷贝过来
  4. updateDependencies(taskFiles, taskJars)
  5. // 最后通过正式的反序列化操作将整个task的数据反序列化回来
  6. // 这里用到了Java的类加载器ClassLoader , 可以通过反射的方式动态加载一个类然后实例化对象,还可以用于指定上下文的相关资源进行加载和读取
  7. task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
上面的代码中的updateDependencies方法可以详细读一下 , 源码如下:
     
     
  1. private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
  2. // 获取Hadoop配置文件
  3. lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
  4. // 这里是用Java的synchronzied进行了多线程并发访问的同步,因为task实际上是以java线程的方式在一个CoarseGrainedExecutorBackend进程内并发运行的
  5. // 如果在执行业务逻辑的时候要访问一些共享的资源那么就可能出现多线程并发访问的安全问题
  6. // 所以spark选择synchronized进行了多线程并发访问的同步
  7. synchronized {
  8. // Fetch missing dependencies
  9. // 遍历要拉取的文件
  10. for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
  11. logInfo("Fetching " + name + " with timestamp " + timestamp)
  12. // Fetch file with useCache mode, close cache for local mode.
  13. // 通过fetchFile的网络通信从远程拉取文件
  14. Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
  15. env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
  16. currentFiles(name) = timestamp
  17. }
  18. // 遍历要拉取的jar
  19. for ((name, timestamp) <- newJars) {
  20. // 判断了一下时间戳 , 要求当前的时间戳必须小于目标时间戳
  21. val localName = name.split("/").last
  22. val currentTimeStamp = currentJars.get(name)
  23. .orElse(currentJars.get(localName))
  24. .getOrElse(-1L)
  25. if (currentTimeStamp < timestamp) {
  26. logInfo("Fetching " + name + " with timestamp " + timestamp)
  27. // Fetch file with useCache mode, close cache for local mode.
  28. // 通过fetchFile()方法拉取jar文件
  29. Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
  30. env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
  31. currentJars(name) = timestamp
  32. // Add it to our class loader
  33. val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
  34. if (!urlClassLoader.getURLs.contains(url)) {
  35. logInfo("Adding " + url + " to class loader")
  36. urlClassLoader.addURL(url)
  37. }
  38. }
  39. }
  40. }
  41. }

2. 运行task相关代码 , 执行我们编写程序中自定义的函数或者方法 , 源码如下:
     
     
  1. // 计算出task的开始时间
  2. taskStart = System.currentTimeMillis()
  3. // 关键步骤: 执行task用的是Task的run方法
  4. // value其实就是MapStatus,封装了ShuffleMapTask计算的数据
  5. // 若是后面还要一个ShuffleMapTask的话那么就会去联系MapOutputTracker来获取上一个ShuffleMapTask的输出位置,然后通过网络拉取数据
  6. // ResultTask也是一样的原理
  7. val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
  8. // 计算出task的结束时间
  9. val taskFinish = System.currentTimeMillis()
重点的代码是这个task.run方法 , 我们跟进去看一下 :
     
     
  1. final def run(taskAttemptId: Long, attemptNumber: Int): T = {
  2. // 创建一个TaskContext表示task的执行上下文 , 里面记录了task执行的一些全局性的数据
  3. // 比如task重试了几次 , 包括task属于哪个stage , task要处理的是rdd的哪个partition等等
  4. context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
  5. taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
  6. TaskContextHelper.setTaskContext(context)
  7. context.taskMetrics.setHostname(Utils.localHostName())
  8. taskThread = Thread.currentThread()
  9. if (_killed) {
  10. kill(interruptThread = false)
  11. }
  12. try {
  13. // 调用抽象方法runTask
  14. runTask(context)
  15. } finally {
  16. context.markTaskCompleted()
  17. TaskContextHelper.unset()
  18. }
  19. }
上面的run方法是Task.scala代码中的 , 里面继续调用runTask方法 , 而这个runTask方法在Task类中是一个抽象方法 , 当调用这个方法的时候根据调用对象的具体类型调用子类的runTask方法 , 而这里的Task有两种,一种是ShuffleMapTask,另一种是ResultTask , 而在前面的分析中ResultTask只有在最后一个stage中才会产生 , 这里先将ShuffleMapTask的runTask发发发:
源码如下:
     
     
  1. /**
  2. * 灰常重要的有点就是ShuffleMapTask的runTask方法有MapStatus返回值
  3. */
  4. override def runTask(context: TaskContext): MapStatus = {
  5. // Deserialize the RDD using the broadcast variable.
  6. // 对Task要处理的RDD相关的数据做一些反序列化操作
  7. // 多个Task运行在多个executor中,而且都是并行运行,可能都不在一个地方,因此task是怎么拿到自己处理的rdd的数据呢?
  8. // 这里其实是会通过broadcast variable 直接拿到
  9. val ser = SparkEnv.get.closureSerializer.newInstance()
  10. val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
  11. ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  12. metrics = Some(context.taskMetrics)
  13. var writer: ShuffleWriter[Any, Any] = null
  14. try {
  15. // 获取ShuffleManager ,从ShuffleManager中获取ShuffleWriter
  16. val manager = SparkEnv.get.shuffleManager
  17. writer = manager.[Any, Any](dep.shuffleHandle, partitionId, context)
  18. // 这里最重要 : 首先调用rdd的iterator方法并且传入了当前task要处理哪个partition
  19. // 核心的逻辑就在rdd的iterator方法中,实现了针对rdd的某个partition,执行我们自己定义的算子或者函数
  20. // 执行完了我们自己定义的算子或者函数其实就是相当于对rdd中的partition执行了处理,处理完了是会有返回的数据的
  21. // 返回的数据都是通过ShuffleWriter经过HashPartition进行分区之后写入自己对应的分区bucket
  22. writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
  23. // 最后返回结果MapStatus,MapStatus里面封装了ShuffleMapTask计算后的数据 , 其实就是BlockManager相关的信息
  24. // BlockManager是spark底层的内存 数据 磁盘的数据管理的组件
  25. // Shuffle之后就是BlockManager的源码分析
  26. return writer.stop(success = true).get
  27. } catch {
  28. case e: Exception =>
  29. try {
  30. if (writer != null) {
  31. writer.stop(success = false)
  32. }
  33. } catch {
  34. case e: Exception =>
  35. log.debug("Could not stop writer", e)
  36. }
  37. throw e
  38. }
  39. }
上面的代码中首先一点就是task去拿自己计算的那份数据是通过Broadcast获取的 , 第二点就是将这份partition数据通过rdd的迭代执行我们自己定义的算子操作
最后结果存储在MapStatus中进行返回 , 对于rdd的iterator方法我们深入一下:
     
     
  1. /**
  2. * 这里的f函数其实就是我们定义的算子和函数执行的地方,但是Spark内部进行了封装还实现了一些其他的逻辑
  3. * 调用到这里为止其实就是在针对rdd的partition执行自定义的计算操作并返回新的rdd的partition的数据
  4. */
  5. override def compute(split: Partition, context: TaskContext) =
  6. f(context, split.index, firstParent[T].iterator(split, context))
  7. }


而ResultTask的runTask就比较简单了 , 源码如下:
     
     
  1. /**
  2. * ResultTask的runTask方法就比较简单了
  3. */
  4. override def runTask(context: TaskContext): U = {
  5. // Deserialize the RDD and the func using the broadcast variables.
  6. // 进行了基本的反序列化
  7. val ser = SparkEnv.get.closureSerializer.newInstance()
  8. val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
  9. ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  10. metrics = Some(context.taskMetrics)
  11. // 执行通过rdd的iterator , 执行我们定义的算子和函数
  12. func(context, rdd.iterator(partition, context))
  13. }

接下来TaskRunner中run方法的第三步 , 将MapStatus数据发送给Driver  并将一些信息显示在4040端口上:
     
     
  1. // 下面的操作其实就是对MapStatus进行了各种序列化和封装,因为后面要发送给Driver(通过网络)
  2. val resultSer = env.serializer.newInstance()
  3. val beforeSerialization = System.currentTimeMillis()
  4. val valueBytes = resultSer.serialize(value)
  5. val afterSerialization = System.currentTimeMillis()
  6. // 计算出了task相关的metrics,统计信息,包括运行了多少时间,反序列化消耗了多长时间
  7. // java虚拟机gc消耗了多长时间,结果的序列化消耗了多长时间
  8. // 这些都会显示在sparkUI上,在4040端口即可查看
  9. for (m <- task.metrics) {
  10. m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
  11. m.setExecutorRunTime(taskFinish - taskStart)
  12. m.setJvmGCTime(gcTime - startGCTime)
  13. m.setResultSerializationTime(afterSerialization - beforeSerialization)
  14. }
  15. val accumUpdates = Accumulators.values
  16. val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
  17. val serializedDirectResult = ser.serialize(directResult)
  18. val resultSize = serializedDirectResult.limit


第四步 , 通知状态改变 , Task执行结束的事件:
     
     
  1. // 这个很重要(核心),调用了Executor所在的CoarseGrainedExecutorBackend的statusUpdate方法
  2. execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
      
      
  1. /**
  2. * 这里会发送StatueUpdate消息给SparkDeploySchedulerBackend
  3. */
  4. override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
  5. driver ! StatusUpdate(executorId, taskId, state, data)
  6. }

CoarseGrainedSchedulerBackend处理状态改变的消息:
     
     
  1. /**
  2. * 处理Task执行结束的事件
  3. */
  4. case StatusUpdate(executorId, taskId, state, data) =>
  5. // 这里调用TaskSchedulerImpl的statusUpdate方法
  6. scheduler.statusUpdate(taskId, state, data.value)
  7. if (TaskState.isFinished(state)) {
  8. executorDataMap.get(executorId) match {
  9. case Some(executorInfo) =>
  10. executorInfo.freeCores += scheduler.CPUS_PER_TASK
  11. makeOffers(executorId)
  12. case None =>
  13. // Ignoring the update since we don't know about the executor.
  14. logWarning(s"Ignored task status update ($taskId state $state) " +
  15. "from unknown executor $sender with ID $executorId")
  16. }
  17. }


TaskScheduleImp处理task状态改变的消息:
     
     
  1. /**
  2. * 处理状态改变的操作
  3. */
  4. def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
  5. var failedExecutor: Option[String] = None
  6. synchronized {
  7. try {
  8. // 判断如果task是lost了,说明在实际编写spark程序的时候可能发现task lost了
  9. // 这个时候就是因为各种各样的原因执行失败
  10. if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
  11. // We lost this entire executor, so remember that it's gone
  12. val execId = taskIdToExecutorId(tid)
  13. // 这里就会移除Executor,将它加入失败队列
  14. if (activeExecutorIds.contains(execId)) {
  15. removeExecutor(execId)
  16. failedExecutor = Some(execId)
  17. }
  18. }
  19. taskIdToTaskSetId.get(tid) match {
  20. // 获取对应的TaskSet
  21. case Some(taskSetId) =>
  22. // 如果task结束了,从内存缓存中移除
  23. if (TaskState.isFinished(state)) {
  24. taskIdToTaskSetId.remove(tid)
  25. taskIdToExecutorId.remove(tid)
  26. }
  27. // 如果正常结束也做相应的处理
  28. activeTaskSets.get(taskSetId).foreach { taskSet =>
  29. if (state == TaskState.FINISHED) {
  30. taskSet.removeRunningTask(tid)
  31. taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
  32. } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
  33. taskSet.removeRunningTask(tid)
  34. taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
  35. }
  36. }
  37. case None =>
  38. logError(
  39. ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
  40. "likely the result of receiving duplicate task finished status updates)")
  41. .format(state, tid))
  42. }
  43. } catch {
  44. case e: Exception => logError("Exception in statusUpdate", e)
  45. }
  46. }
  47. // Update the DAGScheduler without holding a lock on this, since that can deadlock
  48. if (failedExecutor.isDefined) {
  49. dagScheduler.executorLost(failedExecutor.get)
  50. backend.reviveOffers()
  51. }
  52. }

以上就是所有的Executor源码分析和Task源码分析



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值