Spark Internals: Task Scheduling (4) - Task

Since we are discussing task scheduling, we naturally have to talk about what a task actually is. As the heart of the scheduling system, a Task is a single unit of execution, and Spark has two kinds of them: ShuffleMapTask and ResultTask. A Spark job consists of one or more stages: the very last stage is made up of multiple ResultTasks, while the earlier stages are made up of ShuffleMapTasks. A ResultTask executes its work and sends the output back to the driver application; a ShuffleMapTask divides its output into buckets (according to the task's partitioner).

/**
 * A unit of execution. We have two kinds of Task's in Spark:
 * - [[org.apache.spark.scheduler.ShuffleMapTask]]
 * - [[org.apache.spark.scheduler.ResultTask]]
 *
 * A Spark job consists of one or more stages. The very last stage in a job consists of multiple
 * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
 * and sends the task output back to the driver application. A ShuffleMapTask executes the task
 * and divides the task output to multiple buckets (based on the task's partitioner).
 *
 * @param stageId id of the stage this task belongs to
 * @param partitionId index of the number in the RDD
 */
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable 
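
To make the relationship concrete, consider a minimal word-count style job (a hypothetical example, not from the Spark source quoted above): reduceByKey introduces a shuffle dependency and therefore a stage boundary, so the stage before it runs as ShuffleMapTasks (one per input partition) and the final stage behind collect() runs as ResultTasks that ship their results back to the driver. TwoStageJobSketch is an illustrative name.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // pair-RDD implicits needed in Spark 1.x before 1.3

// A two-stage job: stage 0 = ShuffleMapTasks (map side of the shuffle),
// stage 1 = ResultTasks (backs collect(), returns results to the driver).
object TwoStageJobSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("two-stage-job").setMaster("local[2]"))
    val counts = sc.parallelize(Seq("a", "b", "a", "c"), numSlices = 2)
      .map(word => (word, 1))   // narrow transformation, stays inside stage 0
      .reduceByKey(_ + _)       // shuffle dependency: stage boundary
      .collect()                // final stage: ResultTasks send output to the driver
    counts.foreach(println)
    sc.stop()
  }
}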

Next, let's look at how the scheduler actually launches a Task. CoarseGrainedSchedulerBackend sends the LaunchTask message; as the code below shows, CoarseGrainedExecutorBackend is the receiver of that message, and on receipt it calls the executor's launchTask method.

private[spark]
class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem)
  extends SchedulerBackend with Logging
{
    ...

    // Launch tasks returned by a set of resource offers
    def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
      for (task <- tasks.flatten) {
        val ser = SparkEnv.get.closureSerializer.newInstance()
        val serializedTask = ser.serialize(task)
        if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
          val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
          scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
            try {
              var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
                "spark.akka.frameSize or using broadcast variables for large values."
              msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
                AkkaUtils.reservedSizeBytes)
              taskSet.abort(msg)
            } catch {
              case e: Exception => logError("Exception in error callback", e)
            }
          }
        }
        else {
          freeCores(task.executorId) -= scheduler.CPUS_PER_TASK
          executorActor(task.executorId) ! LaunchTask(new SerializableBuffer(serializedTask))
        }
      }
    }
    
    ...
 }
private[spark] class CoarseGrainedExecutorBackend(
    driverUrl: String,
    executorId: String,
    hostPort: String,
    cores: Int,
    sparkProperties: Seq[(String, String)])
  extends Actor with ActorLogReceive with ExecutorBackend with Logging {
    ...
    
  override def receiveWithLogging = {
    ...

    case LaunchTask(data) =>
      if (executor == null) {
        logError("Received LaunchTask command but executor was null")
        System.exit(1)
      } else {
        val ser = SparkEnv.get.closureSerializer.newInstance()
        val taskDesc = ser.deserialize[TaskDescription](data.value)
        logInfo("Got assigned task " + taskDesc.taskId)
        executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask)
      }

    ...
  }

  ...
}
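
The LaunchTask message that links the sender and receiver above is just a small case class carrying the serialized TaskDescription. For reference, its rough shape in Spark 1.x (declared in CoarseGrainedClusterMessages; SerializableBuffer is a thin wrapper that lets the ByteBuffer survive Java serialization) is quoted from memory below, so treat it as a sketch rather than verbatim source.

// Rough shape of the message sent by launchTasks and matched in receiveWithLogging above.
case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage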

Inside Executor's launchTask method, a TaskRunner instance is created and submitted to the executor's thread pool, which either reuses an idle worker thread or starts a new one to run the Task.

  def launchTask(
      context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
    val tr = new TaskRunner(context, taskId, taskName, serializedTask)
    runningTasks.put(taskId, tr)
    threadPool.execute(tr)
  }
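
The threadPool used here is a cached daemon thread pool (Spark builds it with a helper along the lines of Utils.newDaemonCachedThreadPool), so idle worker threads are reused and new ones are created on demand. The sketch below reproduces the same pattern outside Spark; LaunchPatternSketch and its body parameter are illustrative names only.

import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, TimeUnit}

// Minimal sketch of the launchTask pattern: wrap the work in a Runnable, register it in a
// running-task map, and hand it to a cached pool of daemon threads.
object LaunchPatternSketch {
  private val threadPool = Executors.newCachedThreadPool(new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "task-launch-worker")
      t.setDaemon(true)   // daemon threads, mirroring Spark's executor worker threads
      t
    }
  })
  private val runningTasks = new ConcurrentHashMap[Long, Runnable]()

  def launchTask(taskId: Long, body: () => Unit): Unit = {
    val runner = new Runnable {
      // like TaskRunner, deregister the task in a finally block
      override def run(): Unit = try body() finally runningTasks.remove(taskId)
    }
    runningTasks.put(taskId, runner)
    threadPool.execute(runner)
  }

  def main(args: Array[String]): Unit = {
    launchTask(1L, () => println("running task 1"))
    threadPool.shutdown()
    threadPool.awaitTermination(5, TimeUnit.SECONDS)
  }
}
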
TaskRunner implements the Runnable interface. Its run method deserializes the Task from the serialized bytes, performs the necessary initialization, and checks whether the Task has been killed. If it has not been killed, the Task instance's run method is invoked to do the actual work.

  class TaskRunner(
      execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
    extends Runnable {
    ...
    
    override def run() {
      val startTime = System.currentTimeMillis()
      SparkEnv.set(env)
      Thread.currentThread.setContextClassLoader(replClassLoader)
      val ser = SparkEnv.get.closureSerializer.newInstance()
      logInfo(s"Running $taskName (TID $taskId)")
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      var taskStart: Long = 0
      def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
      val startGCTime = gcTime

      try {
        SparkEnv.set(env)
        Accumulators.clear()
        val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
        updateDependencies(taskFiles, taskJars)
        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
        if (killed) {
          // Throw an exception rather than returning, because returning within a try{} block
          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
          // for the task.
          throw new TaskKilledException
        }

        attemptedTask = Some(task)
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)

        // Run the actual task and measure its runtime.
        taskStart = System.currentTimeMillis()
        val value = task.run(taskId.toInt)
        val taskFinish = System.currentTimeMillis()

        // If the task has been killed, let's fail it.
        if (task.killed) {
          throw new TaskKilledException
        }

        val resultSer = SparkEnv.get.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()

        for (m <- task.metrics) {
          m.executorDeserializeTime = taskStart - startTime
          m.executorRunTime = taskFinish - taskStart
          m.jvmGCTime = gcTime - startGCTime
          m.resultSerializationTime = afterSerialization - beforeSerialization
        }

        val accumUpdates = Accumulators.values

        val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
        val serializedDirectResult = ser.serialize(directResult)
        val resultSize = serializedDirectResult.limit

        // directSend = sending directly back to the driver
        val (serializedResult, directSend) = {
          if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
            (ser.serialize(new IndirectTaskResult[Any](blockId)), false)
          } else {
            (serializedDirectResult, true)
          }
        }

        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

        if (directSend) {
          logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
        } else {
          logInfo(
            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
        }
      } catch {
        case ffe: FetchFailedException => {
          val reason = ffe.toTaskEndReason
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
        }

        case _: TaskKilledException | _: InterruptedException if task.killed => {
          logInfo(s"Executor killed $taskName (TID $taskId)")
          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
        }

        case t: Throwable => {
          // Attempt to exit cleanly by informing the driver of our failure.
          // If anything goes wrong (or this was a fatal exception), we will delegate to
          // the default uncaught exception handler, which will terminate the Executor.
          logError(s"Exception in $taskName (TID $taskId)", t)

          val serviceTime = System.currentTimeMillis() - taskStart
          val metrics = attemptedTask.flatMap(t => t.metrics)
          for (m <- metrics) {
            m.executorRunTime = serviceTime
            m.jvmGCTime = gcTime - startGCTime
          }
          val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics)
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

          // Don't forcibly exit unless the exception was inherently fatal, to avoid
          // stopping other tasks unnecessarily.
          if (Utils.isFatalError(t)) {
            ExecutorUncaughtExceptionHandler.uncaughtException(t)
          }
        }
      } finally {
        // Release memory used by this thread for shuffles
        env.shuffleMemoryManager.releaseMemoryForThisThread()
        // Release memory used by this thread for unrolling blocks
        env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
        runningTasks.remove(taskId)
      }
    }
  }
Task's run method creates a TaskContext instance and passes it to runTask. runTask itself is abstract; ShuffleMapTask and ResultTask each supply their own implementation.

  final def run(attemptId: Long): T = {
    context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
    context.taskMetrics.hostname = Utils.localHostName()
    taskThread = Thread.currentThread()
    if (_killed) {
      kill(interruptThread = false)
    }
    runTask(context)
  }
  def runTask(context: TaskContext): T
ResultTask's runTask deserializes the RDD and the user function from taskBinary, computes the partition by driving the RDD's compute chain in lineage (topological) order, and applies the function to the resulting iterator.

override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

  metrics = Some(context.taskMetrics)
  try {
    func(context, rdd.iterator(partition, context))
  } finally {
    context.markTaskCompleted()
  }
}
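
The "topological order" comes from rdd.iterator itself: when a partition is neither cached nor checkpointed, iterator falls through to compute, and the compute of a narrow transformation pulls from its parent's iterator, so the call chain walks the lineage from the final RDD back to the source. The toy model below (simplified, hypothetical classes, not Spark's own) illustrates that recursion.

// Simplified model of rdd.iterator(partition, context) recursing up the lineage.
trait MiniRDD[T] {
  // Real Spark first consults the cache / checkpoint before falling back to compute().
  def iterator(split: Int): Iterator[T] = compute(split)
  def compute(split: Int): Iterator[T]
}

class SourceRDD(data: Array[Seq[Int]]) extends MiniRDD[Int] {
  def compute(split: Int): Iterator[Int] = data(split).iterator
}

class MappedRDD[T, U](parent: MiniRDD[T], f: T => U) extends MiniRDD[U] {
  // A narrow transformation computes by pulling from the parent's iterator for the same split.
  def compute(split: Int): Iterator[U] = parent.iterator(split).map(f)
}

object LineageOrderDemo {
  def main(args: Array[String]): Unit = {
    val source  = new SourceRDD(Array(Seq(1, 2, 3), Seq(4, 5)))
    val doubled = new MappedRDD(source, (x: Int) => x * 2)
    // Driving the last RDD's iterator computes the whole pipeline for partition 0: List(2, 4, 6)
    println(doubled.iterator(0).toList)
  }
}
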
ShuffleMapTask's runTask, by contrast, computes the partition and writes the shuffle output for the downstream stage.

override def runTask(context: TaskContext): MapStatus = {
  // Deserialize the RDD using the broadcast variable.
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  // taskBinary here is the broadcast variable created in
  // org.apache.spark.scheduler.DAGScheduler#submitMissingTasks when the task was serialized

  metrics = Some(context.taskMetrics)
  var writer: ShuffleWriter[Any, Any] = null
  try {
    val manager = SparkEnv.get.shuffleManager
    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
    writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) // write the records computed from the RDD to memory or disk
    return writer.stop(success = true).get
  } catch {
    case e: Exception =>
      if (writer != null) {
        writer.stop(success = false)
      }
      throw e
  } finally {
    context.markTaskCompleted()
  }
}
Both kinds of task compute their partition by invoking the RDD's compute chain in topological order; the difference is that a ShuffleMapTask also performs the shuffle write so that the child stage can later read the shuffled data. The taskBinary used by both is obtained from the broadcast variable created in DAGScheduler's submitMissingTasks, where the stage's task data is serialized.
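
A hedged sketch of that broadcast pattern is shown below. It mirrors the shapes used by DAGScheduler#submitMissingTasks and the runTask deserialization above (serialize once with the closure serializer, broadcast the bytes, deserialize inside the task), but it is a standalone illustration run on the driver in local mode, not the scheduler's actual code; TaskBinarySketch and its simplified func are illustrative names only.

import java.nio.ByteBuffer
import org.apache.spark.{SparkConf, SparkContext, SparkEnv}
import org.apache.spark.rdd.RDD

// Sketch of the taskBinary pattern: serialize (rdd, func) once with the closure serializer,
// broadcast the bytes, and let each task deserialize them inside runTask.
object TaskBinarySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("task-binary-sketch").setMaster("local[2]"))
    val rdd  = sc.parallelize(1 to 10, 2)
    val func = (it: Iterator[Int]) => it.sum

    val ser = SparkEnv.get.closureSerializer.newInstance()
    val taskBinaryBytes: Array[Byte] = ser.serialize((rdd, func): AnyRef).array()
    val taskBinary = sc.broadcast(taskBinaryBytes)          // what runTask later unwraps

    // Mirrors the deserialization at the top of ResultTask.runTask / ShuffleMapTask.runTask:
    val (rdd2, func2) = ser.deserialize[(RDD[Int], Iterator[Int] => Int)](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    println(s"broadcast ${taskBinaryBytes.length} bytes; recovered ${rdd2.getClass.getSimpleName}, " +
      s"func2(1..3) = ${func2(Iterator(1, 2, 3))}")
    sc.stop()
  }
}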
