源码-stage->task->taskSet->executor

本文:接《DAGScheduler及Stage划分提交》分析Stage中得Task是如何生成并且最终提交到Executor中去的。

从org.apache.spark.scheduler.DAGScheduler#submitMissingTasks开始,分析Stage是如何生成TaskSet的。

 private def submitMissingTasks(stage: Stage, jobId: Int) {
    logDebug("submitMissingTasks(" + stage + ")")
    // Get our pending tasks and remember them in our pendingTasks entry
    stage.pendingPartitions.clear()

    // First figure out the indexes of partition ids to compute.
    val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()

    // Create internal accumulators if the stage has no accumulators initialized.
    // Reset internal accumulators only if this stage is not partially submitted
    // Otherwise, we may override existing accumulator values from some tasks
    if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) {
      stage.resetInternalAccumulators()
    }

    // Use the scheduling pool, job group, description, etc. from an ActiveJob associated
    // with this Stage
    val properties = jobIdToActiveJob(jobId).properties

    runningStages += stage
    // SparkListenerStageSubmitted should be posted before testing whether tasks are
    // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
    // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
    // event.
    stage match {
      case s: ShuffleMapStage =>
        outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1)
      case s: ResultStage =>
        outputCommitCoordinator.stageStart(
          stage = s.id, maxPartitionId = s.rdd.partitions.length - 1)
    }
    val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
      stage match {
        case s: ShuffleMapStage =>
          partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
        case s: ResultStage =>
          val job = s.activeJob.get
          partitionsToCompute.map { id =>
            val p = s.partitions(id)
            (id, getPreferredLocs(stage.rdd, p))
          }.toMap
      }
    } catch {
      case NonFatal(e) =>
        stage.makeNewStageAttempt(partitionsToCompute.size)
        listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
        runningStages -= stage
        return
    }

    stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
    listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))

    // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
    // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
    // the serialized copy of the RDD and for each task we will deserialize it, which means each
    // task gets a different copy of the RDD. This provides stronger isolation between tasks that
    // might modify state of objects referenced in their closures. This is necessary in Hadoop
    // where the JobConf/Configuration object is not thread-safe.
    var taskBinary: Broadcast[Array[Byte]] = null
    try {
      // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
      // For ResultTask, serialize and broadcast (rdd, func).
      val taskBinaryBytes: Array[Byte] = stage match {
        case stage: ShuffleMapStage =>
          closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array()
        case stage: ResultStage =>
          closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array()
      }

      taskBinary = sc.broadcast(taskBinaryBytes)
    } catch {
      // In the case of a failure during serialization, abort the stage.
      case e: NotSerializableException =>
        abortStage(stage, "Task not serializable: " + e.toString, Some(e))
        runningStages -= stage

        // Abort execution
        return
      case NonFatal(e) =>
        abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e))
        runningStages -= stage
        return
    }

    val tasks: Seq[Task[_]] = try {
      stage match {
        case stage: ShuffleMapStage =>
          partitionsToCompute.map { id =>
            val locs = taskIdToLocations(id)
            val part = stage.rdd.partitions(id)
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, stage.internalAccumulators)
          }

        case stage: ResultStage =>
          val job = stage.activeJob.get
          partitionsToCompute.map { id =>
            val p: Int = stage.partitions(id)
            val part = stage.rdd.partitions(p)
            val locs = taskIdToLocations(id)
            new ResultTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, id, stage.internalAccumulators)
          }
      }
    } catch {
      case NonFatal(e) =>
        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
        runningStages -= stage
        return
    }

    if (tasks.size > 0) {
      logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
      stage.pendingPartitions ++= tasks.map(_.partitionId)
      logDebug("New pending partitions: " + stage.pendingPartitions)
      taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
      stage.latestInfo.submissionTime = Some(clock.getTimeMillis())

    } else {
      // Because we posted SparkListenerStageSubmitted earlier, we should mark
      // the stage as completed here in case there are no tasks to run
      markStageAsFinished(stage, None)

      val debugString = stage match {
        case stage: ShuffleMapStage =>
          s"Stage ${stage} is actually done; " +
            s"(available: ${stage.isAvailable}," +
            s"available outputs: ${stage.numAvailableOutputs}," +
            s"partitions: ${stage.numPartitions})"
        case stage : ResultStage =>
          s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
      }
      logDebug(debugString)
    }
  }

org.apache.spark.scheduler.DAGScheduler#submitMissingTasks的计算流程如下:

  1. 首先得到RDD中需要计算的partition,对于Shuffle类型的stage,需要判断stage中是否缓存了该结果;对于Result类型的Final Stage,则判断计算Job中该partition是否已经计算完成。
  2. 序列化task的binary。Executor可以通过广播变量得到它。每个task运行的时候首先会反序列化。这样在不同的executor上运行的task是隔离的,不会相互影响。
  3. 为每个需要计算的partition生成一个task:对于Shuffle类型依赖的Stage,生成ShuffleMapTask类型的task;对于Result类型的Stage,生成一个ResultTask类型的task
  4. 确保Task是可以被序列化的。因为不同的cluster有不同的taskScheduler,在这里判断可以简化逻辑;保证TaskSet的task都是可以序列化的
  5. 通过TaskScheduler提交TaskSet。
TaskSet就是可以做pipeline的一组完全相同的task,每个task的处理逻辑完全相同,不同的是处理数据,每个task负责处理一个partition。pipeline,可以称为大数据处理的基石,只有数据进行pipeline处理,才能将其放到集群中去运行。对于一个task来说,它从数据源获得逻辑,然后按照拓扑顺序,顺序执行(实际上是调用rdd的compute)。
TaskSet是一个数据结构,存储了这一组task:
[java] view plain copy 在CODE上查看代码片 派生到我的代码片
  1. private[spark] class TaskSet(  
  2.     val tasks: Array[Task[_]],  
  3.     val stageId: Int,  
  4.     val attempt: Int,  
  5.     val priority: Int,  
  6.     val properties: Properties) {  
  7.     val id: String = stageId + "." + attempt  
  8.   
  9.   override def toString: String = "TaskSet " + id  
  10. }  


管理调度这个TaskSet的时org.apache.spark.scheduler.TaskSetManager,TaskSetManager会负责task的失败重试;跟踪每个task的执行状态;处理locality-aware的调用。
详细的调用堆栈如下:
  1. org.apache.spark.scheduler.TaskSchedulerImpl#submitTasks
  2. org.apache.spark.scheduler.SchedulableBuilder#addTaskSetManager
  3. org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend#reviveOffers
(1) override def reviveOffers() {
    driverEndpoint.send(ReviveOffers)
  }

  我们继续看driverEndpoint是什么鬼。driverEndpoint是RPC中driver端Endpoint的引用,其类型为RpcEndpointRef。在CoarseGrainedSchedulerBackend启动时的start()方法中,对driverEndpoint进行了赋值:

  1. // TODO (prashant) send conf instead of properties  
  2.     driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties))  
        这个RpcEnv只是一个抽象类,它有两种实现,一个是基于AKKa的AkkaRpcEnv,另外一个则是基于Netty的NettyRpcEnv,默认的实现是Netty。通过下述RpcEnv的代码即可看出:
  1. private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = {  
  2.       
  3.     // 两种实现方式:  
  4.     // akka:org.apache.spark.rpc.akka.AkkaRpcEnvFactory  
  5.     // netty:org.apache.spark.rpc.netty.NettyRpcEnvFactory  
  6.     val rpcEnvNames = Map(  
  7.       "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",  
  8.       "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory")  
  9.       
  10.     // 通过参数spark.rpc配置,默认为netty  
  11.     val rpcEnvName = conf.get("spark.rpc""netty")  
  12.     val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)  
  13.     Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory]  
  14.   }  

        下面,我们就看下Netty的概要实现,在NettyRpcEnv的setupEndpoint()方法中:

  1. override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {  
  2.       
  3.     // 调用Dispatcher的registerRpcEndpoint()方法完成注册  
  4.     dispatcher.registerRpcEndpoint(name, endpoint)  
  5.   }  
        它是通过dispatcher来完成endpoint注册的,name为“CoarseGrainedScheduler”,RpcEndpoint为CoarseGrainedSchedulerBackend中通过createDriverEndpoint()方法创建的DriverEndpoint对象。代码如下:
  1. protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {  
  2.     new DriverEndpoint(rpcEnv, properties)  
  3.   }  
        那么这个DriverEndpoint是什么类呢?我们发现它继承自ThreadSafeRpcEndpoint,继而继承RpcEndpoint这个类。这里,我们只要知道这个RpcEndpoint是进程间消息传递调用的一个端点,定义了消息触发的函数。当一个消息到来时,方法调用顺序为  onStart, receive, onStop。它的生命周期为constructor -> onStart -> receive* -> onStop。

        为什么要用RpcEndpoint呢?很简单,Task的调度与执行是在一个分布式集群上进行的,自然需要进程间的通讯。

        继续分析,那么上面提到的driverEndpoint是如何赋值的呢?我们继续看Dispatcher的registerRpcEndpoint()方法,因为最终是由它向上返回RpcEndpointRef来完成driverEndpoint的赋值的。代码如下:

  1. // 注册RpcEndpoint  
  2.   // name为“Master”,endpoint为Master对象  
  3.   def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {  
  4.       
  5.     // 创建RpcEndpointAddress  
  6.     val addr = RpcEndpointAddress(nettyEnv.address, name)  
  7.       
  8.     // 创建NettyRpcEndpointRef  
  9.     val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)  
  10.       
  11.     // 同步代码块  
  12.     synchronized {  
  13.       if (stopped) {  
  14.         throw new IllegalStateException("RpcEnv has been stopped")  
  15.       }  
  16.         
  17.       // ConcurrentHashMap的putIfAbsent()方法确保不会重复创建EndpointData  
  18.       if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {  
  19.         throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")  
  20.       }  
  21.         
  22.         
  23.       val data = endpoints.get(name)  
  24.       endpointRefs.put(data.endpoint, data.ref)  
  25.       receivers.offer(data)  // for the OnStart message  
  26.     }  
  27.     endpointRef  
  28.   }  
        返回的RpcEndpointRef为NettyRpcEndpointRef类型,而RpcEndpointRef则是一个远程RpcEndpoint的引用,通过它可以给远程RpcEndpoint发送消息,可以是同步可以是异步,它映射一个地址。这么看来,我们在远端(ps:另外的机器或者进程)注册了一个RpcEndpoint,即DriverEndpoint,而在本地端(当前机器或者进程)则持有一个RpcEndpoint的引用,即NettyRpcEndpointRef,可以由它来往远端发送消息,那么发送的是什么消息呢?我们现在返回CoarseGrainedSchedulerBackend中的reviveOffers()方法,发现发送的是ReviveOffers消息。这里只是发送,具体处理还要看远端的RpcEndpoint,即DriverEndpoint。通过上面我们可以知道,RpcEndpoint的服务流程为onStart()-->receive()--> onStop(),每当消息来临时,DriverEndpoint都会调用receive()方法来处理。


(2)override def receive: PartialFunction[Any, Unit] = {

      case StatusUpdate(executorId, taskId, state, data) =>
        scheduler.statusUpdate(taskId, state, data.value)
        if (TaskState.isFinished(state)) {
          executorDataMap.get(executorId) match {
            case Some(executorInfo) =>
              executorInfo.freeCores += scheduler.CPUS_PER_TASK
              makeOffers(executorId)
            case None =>
              // Ignoring the update since we don't know about the executor.
              logWarning(s"Ignored task status update ($taskId state $state) " +
                s"from unknown executor with ID $executorId")
          }
        }

      case ReviveOffers =>
        makeOffers()

3)   private def makeOffers() {
      // Filter out executors under killing
      val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
      val workOffers = activeExecutors.map { case (id, executorData) =>
        new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
      }.toSeq
      launchTasks(scheduler.resourceOffers(workOffers))--进入到TaskSchedulerImpl
    }

(4)TaskSchedulerImpl->resourceOffers

  def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
    // Mark each slave as alive and remember its hostname
    // Also track if new executor is added
    var newExecAvail = false
    for (o <- offers) {
      executorIdToHost(o.executorId) = o.host
      executorIdToTaskCount.getOrElseUpdate(o.executorId, 0)
      if (!executorsByHost.contains(o.host)) {
        executorsByHost(o.host) = new HashSet[String]()
        executorAdded(o.executorId, o.host)
        newExecAvail = true
      }
      for (rack <- getRackForHost(o.host)) {
        hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host
      }
    }

    // Randomly shuffle offers to avoid always placing tasks on the same set of workers.
    val shuffledOffers = Random.shuffle(offers)
    // Build a list of tasks to assign to each worker.
    val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
    val availableCpus = shuffledOffers.map(o => o.cores).toArray
    val sortedTaskSets = rootPool.getSortedTaskSetQueue
    for (taskSet <- sortedTaskSets) {
      logDebug("parentName: %s, name: %s, runningTasks: %s".format(
        taskSet.parent.name, taskSet.name, taskSet.runningTasks))
      if (newExecAvail) {
        taskSet.executorAdded()
      }
    }

    // Take each TaskSet in our scheduling order, and then offer it each node in increasing order
    // of locality levels so that it gets a chance to launch local tasks on all of them.
    // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY
    var launchedTask = false
    for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
      do {
        launchedTask = resourceOfferSingleTaskSet(
            taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
      } while (launchedTask)
    }

    if (tasks.size > 0) {
      hasLaunchedTask = true
    }
    return tasks
  }


关于resourceOffers的详细介绍见:http://www.jianshu.com/p/9a059ace2f3a/comments/1495052(
【SubStep1】: executor, host, rack等信息更新)


(5) private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
      for (task <- tasks.flatten) {
        val serializedTask = ser.serialize(task)
        if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
          scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
            try {
              var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
                "spark.akka.frameSize or using broadcast variables for large values."
              msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
                AkkaUtils.reservedSizeBytes)
              taskSetMgr.abort(msg)
            } catch {
              case e: Exception => logError("Exception in error callback", e)
            }
          }
        }
        else {
          val executorData = executorDataMap(task.executorId)
          executorData.freeCores -= scheduler.CPUS_PER_TASK
          executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
        }
      }
    }

  launchTasks的执行逻辑很简单,针对传入的TaskDescription序列,循环每个Task,做以下处理:

        1、首先对Task进行序列化,得到serializedTask;

        2、针对序列化后的Task:serializedTask,判断其大小:

              2.1、序列化后的task的大小达到或超出规定的上限,即框架配置的Akka消息最大大小,减去除序列化task或task结果外,一个Akka消息需要保留的额外大小的值,则根据task的taskId,在TaskSchedulerImpl的taskIdToTaskSetManager中获取对应的TaskSetManager,并调用其abort()方法,标记对应TaskSetManager为失败;

              2.2、序列化后的task的大小未达到上限,在规定的大小范围内,则:

                       2.2.1、从executorDataMap中,根据task.executorId获取executor描述信息executorData;

                       2.2.2、在executorData中,freeCores做相应减少;

                       2.2.3、利用executorData中的executorEndpoint,即Driver端executor通讯端点的引用,发送LaunchTask事件,LaunchTask事件中包含序列化后的task,将Task传递到executor中去执行。那么executor中是如何接收LaunchTask事件的呢?答案就在CoarseGrainedExecutorBackend中。

  1. private[spark] class CoarseGrainedExecutorBackend(  
  2.     override val rpcEnv: RpcEnv,  
  3.     driverUrl: String,  
  4.     executorId: String,  
  5.     hostPort: String,  
  6.     cores: Int,  
  7.     userClassPath: Seq[URL],  
  8.     env: SparkEnv)  
  9.   extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {  
        由上面的代码我们可以知道,它实现了ThreadSafeRpcEndpoint和ExecutorBackend两个trait,而ExecutorBackend的定义如下:
  1. /** 
  2.  * A pluggable interface used by the Executor to send updates to the cluster scheduler. 
  3.  * 一个被Executor用来发送更新到集群调度器的可插拔接口。 
  4.  */  
  5. private[spark] trait ExecutorBackend {  
  6.     
  7.   // 唯一的一个statusUpdate()方法  
  8.   // 需要Long类型的taskId、TaskState类型的state、ByteBuffer类型的data三个参数  
  9.   def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)  
  10. }  

        那么它自然就有两种主要的任务,第一,作为endpoint提供driver与executor间的通讯功能;第二,提供了executor任务执行时状态汇报的功能。

        CoarseGrainedExecutorBackend到底是什么呢?这里我们先不深究,留到以后分析,你只要知道它是Executor的一个后台辅助进程,和Executor是一对一的关系,向Executor提供了与Driver通讯、任务执行时状态汇报两个基本功能即可。

        接下来,我们看下CoarseGrainedExecutorBackend是如何处理LaunchTask事件的。做为RpcEndpoint,在其处理各类事件或消息的receive()方法中,定义如下:

  1. case LaunchTask(data) =>  
  2.       if (executor == null) {  
  3.         logError("Received LaunchTask command but executor was null")  
  4.         System.exit(1)  
  5.       } else {  
  6.         
  7.         // 反序列话task,得到taskDesc  
  8.         val taskDesc = ser.deserialize[TaskDescription](data.value)  
  9.         logInfo("Got assigned task " + taskDesc.taskId)  
  10.           
  11.         // 调用executor的launchTask()方法加载task  
  12.         executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,  
  13.           taskDesc.name, taskDesc.serializedTask)  
  14.       }  
        首先,会判断对应的executor是否为空,为空的话,记录错误日志并退出,不为空的话,则按照如下流程处理:

        1、反序列话task,得到taskDesc;

        2、调用executor的launchTask()方法加载task。

        那么,重点就落在了Executor的launchTask()方法中,代码如下:

  1. def launchTask(  
  2.       context: ExecutorBackend,  
  3.       taskId: Long,  
  4.       attemptNumber: Int,  
  5.       taskName: String,  
  6.       serializedTask: ByteBuffer): Unit = {  
  7.         
  8.     // 新建一个TaskRunner  
  9.     val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,  
  10.       serializedTask)  
  11.         
  12.     // 将taskId与TaskRunner的对应关系存入runningTasks  
  13.     runningTasks.put(taskId, tr)  
  14.       
  15.     // 线程池执行TaskRunner  
  16.     threadPool.execute(tr)  
  17.   }  
        非常简单,创建一个TaskRunner对象,然后将taskId与TaskRunner的对应关系存入runningTasks,将TaskRunner扔到线程池中去执行即可。

        我们先看下这个TaskRunner类。我们先看下Class及其成员变量的定义,如下:

  1. class TaskRunner(  
  2.       execBackend: ExecutorBackend,  
  3.       val taskId: Long,  
  4.       val attemptNumber: Int,  
  5.       taskName: String,  
  6.       serializedTask: ByteBuffer)  
  7.     extends Runnable {  
  8.       
  9.     // TaskRunner继承了Runnable  
  10.   
  11.     /** Whether this task has been killed. */  
  12.     // 标志位,task是否被杀掉  
  13.     @volatile private var killed = false  
  14.   
  15.     /** How much the JVM process has spent in GC when the task starts to run. */  
  16.     @volatile var startGCTime: Long = _  
  17.   
  18.     /** 
  19.      * The task to run. This will be set in run() by deserializing the task binary coming 
  20.      * from the driver. Once it is set, it will never be changed. 
  21.      *  
  22.      * 需要运行的task。它将在反序列化来自driver的task二进制数据时在run()方法被设置,一旦被设置,它将不会再发生改变。 
  23.      */  
  24.     @volatile var task: Task[Any] = _  
  25. }  
        由类的定义我们可以看出,TaskRunner继承了Runnable,所以它本质上是一个线程,故其可以被放到线程池中去运行。它所包含的成员变量,主要有以下几个:

        1、execBackend:Executor后台辅助进程,提供了与Driver通讯、状态汇报等两大基本功能,实际上传入的是CoarseGrainedExecutorBackend实例;

        2、taskId:Task的唯一标识;

        3、attemptNumber:Task运行的序列号,Spark与MapReduce一样,可以为拖后腿任务启动备份任务,即推测执行原理,如此,就需要通过taskId加attemptNumber来唯一标识一个Task运行实例;

        4、serializedTask:ByteBuffer类型,序列化后的Task,包含的是Task的内容,通过发序列化它来得到Task,并运行其中的run()方法来执行Task;

        5、killed:Task是否被杀死的标志位;

        6、task:Task[Any]类型,需要运行的Task,它将在反序列化来自driver的task二进制数据时在run()方法被设置,一旦被设置,它将不会再发生改变;

       7、startGCTime:JVM在task开始运行后,进行垃圾回收的时间。

        另外,既然是一个线程,TaskRunner必须得提供run()方法,该run()方法就是TaskRunner线程在线程池中被调度时,需要执行的方法,我们来看下它的定义:

  1. override def run(): Unit = {  
  2.       
  3.       // Step1:Task及其运行时需要的辅助对象构造  
  4.         
  5.       // 获取任务内存管理器  
  6.       val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)  
  7.         
  8.       // 反序列化开始时间  
  9.       val deserializeStartTime = System.currentTimeMillis()  
  10.         
  11.       // 当前线程设置上下文类加载器  
  12.       Thread.currentThread.setContextClassLoader(replClassLoader)  
  13.         
  14.       // 从SparkEnv中获取序列化器  
  15.       val ser = env.closureSerializer.newInstance()  
  16.       logInfo(s"Running $taskName (TID $taskId)")  
  17.         
  18.       // execBackend更新状态TaskState.RUNNING  
  19.       execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)  
  20.       var taskStart: Long = 0  
  21.         
  22.       // 计算垃圾回收的时间  
  23.       startGCTime = computeTotalGcTime()  
  24.   
  25.       try {  
  26.         // 调用Task的deserializeWithDependencies()方法,反序列化Task,得到Task运行需要的文件taskFiles、jar包taskFiles和Task二进制数据taskBytes  
  27.         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)  
  28.         updateDependencies(taskFiles, taskJars)  
  29.           
  30.         // 反序列化Task二进制数据taskBytes,得到task实例  
  31.         task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)  
  32.           
  33.         // 设置Task的任务内存管理器  
  34.         task.setTaskMemoryManager(taskMemoryManager)  
  35.   
  36.         // If this task has been killed before we deserialized it, let's quit now. Otherwise,  
  37.         // continue executing the task.  
  38.         // 如果此时Task被kill,抛出异常,快速退出  
  39.         if (killed) {  
  40.           // Throw an exception rather than returning, because returning within a try{} block  
  41.           // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl  
  42.           // exception will be caught by the catch block, leading to an incorrect ExceptionFailure  
  43.           // for the task.  
  44.           throw new TaskKilledException  
  45.         }  
  46.   
  47.         logDebug("Task " + taskId + "'s epoch is " + task.epoch)  
  48.         // mapOutputTracker更新Epoch  
  49.         env.mapOutputTracker.updateEpoch(task.epoch)  
  50.   
  51.         // Run the actual task and measure its runtime.  
  52.         // 运行真正的task,并度量它的运行时间  
  53.           
  54.         // Step2:Task运行  
  55.           
  56.         // task开始时间  
  57.         taskStart = System.currentTimeMillis()  
  58.           
  59.         // 标志位threwException设置为true,标识Task真正执行过程中是否抛出异常  
  60.         var threwException = true  
  61.           
  62.         // 调用Task的run()方法,真正执行Task,并获得运行结果value  
  63.         val (value, accumUpdates) = try {  
  64.           
  65.           // 调用Task的run()方法,真正执行Task  
  66.           val res = task.run(  
  67.             taskAttemptId = taskId,  
  68.             attemptNumber = attemptNumber,  
  69.             metricsSystem = env.metricsSystem)  
  70.             
  71.           // 标志位threwException设置为false  
  72.           threwException = false  
  73.             
  74.           // 返回res,Task的run()方法中,res的定义为(T, AccumulatorUpdates)  
  75.           // 这里,前者为任务运行结果,后者为累加器更新  
  76.           res  
  77.         } finally {  
  78.             
  79.           // 通过任务内存管理器清理所有的分配的内存  
  80.           val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()  
  81.           if (freedMemory > 0) {  
  82.             val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"  
  83.             if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak"false) && !threwException) {  
  84.               throw new SparkException(errMsg)  
  85.             } else {  
  86.               logError(errMsg)  
  87.             }  
  88.           }  
  89.         }  
  90.           
  91.         // task完成时间  
  92.         val taskFinish = System.currentTimeMillis()  
  93.   
  94.         // If the task has been killed, let's fail it.  
  95.         // 如果task被杀死,抛出TaskKilledException异常  
  96.         if (task.killed) {  
  97.           throw new TaskKilledException  
  98.         }  
  99.   
  100.         // Step3:Task运行结果处理  
  101.           
  102.         // 通过Spark获取Task运行结果序列化器  
  103.         val resultSer = env.serializer.newInstance()  
  104.           
  105.         // 结果序列化前的时间点  
  106.         val beforeSerialization = System.currentTimeMillis()  
  107.           
  108.         // 利用Task运行结果序列化器序列化Task运行结果,得到valueBytes  
  109.         val valueBytes = resultSer.serialize(value)  
  110.           
  111.         // 结果序列化后的时间点  
  112.         val afterSerialization = System.currentTimeMillis()  
  113.   
  114.         // 度量指标体系相关,暂不介绍  
  115.         for (m <- task.metrics) {  
  116.           // Deserialization happens in two parts: first, we deserialize a Task object, which  
  117.           // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.  
  118.           m.setExecutorDeserializeTime(  
  119.             (taskStart - deserializeStartTime) + task.executorDeserializeTime)  
  120.           // We need to subtract Task.run()'s deserialization time to avoid double-counting  
  121.           m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)  
  122.           m.setJvmGCTime(computeTotalGcTime() - startGCTime)  
  123.           m.setResultSerializationTime(afterSerialization - beforeSerialization)  
  124.           m.updateAccumulators()  
  125.         }  
  126.   
  127.         // 构造DirectTaskResult,同时包含Task运行结果valueBytes和累加器更新值accumulator updates  
  128.         val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)  
  129.           
  130.         // 序列化DirectTaskResult,得到serializedDirectResult  
  131.         val serializedDirectResult = ser.serialize(directResult)  
  132.           
  133.         // 获取Task运行结果大小  
  134.         val resultSize = serializedDirectResult.limit  
  135.   
  136.         // directSend = sending directly back to the driver  
  137.         // directSend的意思就是直接发送结果至Driver端  
  138.         val serializedResult: ByteBuffer = {  
  139.           
  140.           // 如果Task运行结果大小大于所有Task运行结果的最大大小,序列化IndirectTaskResult  
  141.           // IndirectTaskResult为存储在Worker上BlockManager中DirectTaskResult的一个引用  
  142.           if (maxResultSize > 0 && resultSize > maxResultSize) {  
  143.             logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +  
  144.               s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +  
  145.               s"dropping it.")  
  146.             ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))  
  147.           }  
  148.           // 如果 Task运行结果大小超过Akka除去需要保留的字节外最大大小,则将结果写入BlockManager  
  149.           // 即运行结果无法通过消息传递  
  150.           else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {  
  151.               
  152.             val blockId = TaskResultBlockId(taskId)  
  153.             env.blockManager.putBytes(  
  154.               blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)  
  155.             logInfo(  
  156.               s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")  
  157.             ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))  
  158.           }   
  159.           // Task运行结果比较小的话,直接返回,通过消息传递  
  160.           else {  
  161.             logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")  
  162.             serializedDirectResult  
  163.           }  
  164.         }  
  165.   
  166.         // execBackend更新状态TaskState.FINISHED  
  167.         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)  
  168.   
  169.       } catch {// 处理各种异常信息  
  170.           
  171.         case ffe: FetchFailedException =>  
  172.           val reason = ffe.toTaskEndReason  
  173.           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))  
  174.   
  175.         case _: TaskKilledException | _: InterruptedException if task.killed =>  
  176.           logInfo(s"Executor killed $taskName (TID $taskId)")  
  177.           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))  
  178.   
  179.         case cDE: CommitDeniedException =>  
  180.           val reason = cDE.toTaskEndReason  
  181.           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))  
  182.   
  183.         case t: Throwable =>  
  184.           // Attempt to exit cleanly by informing the driver of our failure.  
  185.           // If anything goes wrong (or this was a fatal exception), we will delegate to  
  186.           // the default uncaught exception handler, which will terminate the Executor.  
  187.           logError(s"Exception in $taskName (TID $taskId)", t)  
  188.   
  189.           val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>  
  190.             task.metrics.map { m =>  
  191.               m.setExecutorRunTime(System.currentTimeMillis() - taskStart)  
  192.               m.setJvmGCTime(computeTotalGcTime() - startGCTime)  
  193.               m.updateAccumulators()  
  194.               m  
  195.             }  
  196.           }  
  197.           val serializedTaskEndReason = {  
  198.             try {  
  199.               ser.serialize(new ExceptionFailure(t, metrics))  
  200.             } catch {  
  201.               case _: NotSerializableException =>  
  202.                 // t is not serializable so just send the stacktrace  
  203.                 ser.serialize(new ExceptionFailure(t, metrics, false))  
  204.             }  
  205.           }  
  206.             
  207.           // execBackend更新状态TaskState.FAILED  
  208.           execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)  
  209.   
  210.           // Don't forcibly exit unless the exception was inherently fatal, to avoid  
  211.           // stopping other tasks unnecessarily.  
  212.           if (Utils.isFatalError(t)) {  
  213.             SparkUncaughtExceptionHandler.uncaughtException(t)  
  214.           }  
  215.   
  216.       } finally {  
  217.         
  218.         // 最后,无论运行成功还是失败,将task从runningTasks中移除  
  219.         runningTasks.remove(taskId)  
  220.       }  
  221.     }  
        如此长的一个方法,好长好大,哈哈!不过,纵观全篇,无非三个Step就可搞定:

        1、Step1:Task及其运行时需要的辅助对象构造;

        2、Step2:Task运行;

        3、Step3:Task运行结果处理。

        对, 就这么简单!鉴于时间与篇幅问题,我们这里先讲下主要流程,细节方面的东西留待下节继续。

        下面,我们一个个Step来看,首先看下Step1:Task及其运行时需要的辅助对象构造,主要包括以下步骤:

        1.1、构造TaskMemoryManager任务内存管理器,即taskMemoryManager;

        1.2、记录反序列化开始时间;

        1.3、当前线程设置上下文类加载器;

        1.4、从SparkEnv中获取序列化器ser;

        1.5、execBackend更新状态TaskState.RUNNING;

        1.6、计算垃圾回收时间;

        1.7、调用Task的deserializeWithDependencies()方法,反序列化Task,得到Task运行需要的文件taskFiles、jar包taskFiles和Task二进制数据taskBytes;

        1.8、反序列化Task二进制数据taskBytes,得到task实例;

        1.9、设置Task的任务内存管理器;

        1.10、如果此时Task被kill,抛出异常,快速退出;

        

        接下来,是Step2:Task运行,主要流程如下:

        2.1、获取task开始时间;

        2.2、标志位threwException设置为true,标识Task真正执行过程中是否抛出异常;

        2.3、调用Task的run()方法,真正执行Task,并获得运行结果value,和累加器更新accumUpdates;

        2.4、标志位threwException设置为false;

        2.5、通过任务内存管理器taskMemoryManager清理所有的分配的内存;

        2.6、获取task完成时间;

        2.7、如果task被杀死,抛出TaskKilledException异常。

        

        最后一步,Step3:Task运行结果处理,大体流程如下:

        3.1、通过SparkEnv获取Task运行结果序列化器;

        3.2、获取结果序列化前的时间点;

        3.3、利用Task运行结果序列化器序列化Task运行结果value,得到valueBytes;

        3.4、获取结果序列化后的时间点;

        3.5、度量指标体系相关,暂不介绍;

        3.6、构造DirectTaskResult,同时包含Task运行结果valueBytes和累加器更新值accumulator updates;

        3.7、序列化DirectTaskResult,得到serializedDirectResult;

        3.8、获取Task运行结果大小;

        3.9、处理Task运行结果:

                 3.9.1、如果Task运行结果大小大于所有Task运行结果的最大大小,序列化IndirectTaskResult,IndirectTaskResult为存储在Worker上BlockManager中DirectTaskResult的一个引用;

                 3.9.2、如果 Task运行结果大小超过Akka除去需要保留的字节外最大大小,则将结果写入BlockManager,Task运行结果比较小的话,直接返回,通过消息传递;

                 3.9.3、Task运行结果比较小的话,直接返回,通过消息传递

        3.10、execBackend更新状态TaskState.FINISHED;

        最后,无论运行成功还是失败,将task从runningTasks中移除。


参考:http://blog.csdn.net/lipeng_bigdata/article/details/50726216

http://blog.csdn.net/anzhsoft/article/details/40238111



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值