既然讨论任务调度，我们自然要先谈谈究竟什么是任务（Task）。任务是任务调度系统的灵魂，是一个单独执行的单位。Spark中有两种类型的任务：ShuffleMapTask和ResultTask。在Spark中，一个job包含一个或多个stage，其中最后一个stage由多个ResultTask组成，而较早的stage则由多个ShuffleMapTask组成。ResultTask执行任务后，将输出结果回传给driver application；ShuffleMapTask执行任务后，根据任务的partitioner将输出结果划分到多个bucket中。
/**
* A unit of execution. We have two kinds of Tasks in Spark:
* - [[org.apache.spark.scheduler.ShuffleMapTask]]
* - [[org.apache.spark.scheduler.ResultTask]]
*
* A Spark job consists of one or more stages. The very last stage in a job consists of multiple
* ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
* and sends the task output back to the driver application. A ShuffleMapTask executes the task
* and divides the task output into multiple buckets (based on the task's partitioner).
*
* @param stageId id of the stage this task belongs to
* @param partitionId index of the partition in the RDD
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable
// Class body elided in this excerpt. T is the task's result type; the
// `final def run(attemptId: Long): T` and abstract `def runTask(context: TaskContext): T`
// excerpted later in this document presumably belong to this class — confirm against Task.scala.
// Scheduler backend that launches tasks by sending LaunchTask messages to executor actors.
// (Excerpt: most of the class body and several closing braces are elided.)
class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem)
extends SchedulerBackend with Logging
// Launch tasks returned by a set of resource offers
// Each TaskDescription (the per-offer lists are flattened) is serialized with the closure
// serializer. If the serialized size reaches the Akka frame size minus the reserved bytes,
// the task is not sent; an error message is built for its task set instead. Otherwise the
// executor's free-core count is reduced by CPUS_PER_TASK and a LaunchTask message carrying
// the serialized task is sent to that executor's actor.
def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
for (task <- tasks.flatten) {
val ser = SparkEnv.get.closureSerializer.newInstance()
val serializedTask = ser.serialize(task)
// Too large to fit in a single Akka message.
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
// Nothing happens here if the task set is no longer in activeTaskSets.
scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
"spark.akka.frameSize or using broadcast variables for large values."
// NOTE(review): this call is cut off in the excerpt. The template has five placeholders,
// so the missing last argument is presumably AkkaUtils.reservedSizeBytes. The statement
// that consumes `msg` (presumably aborting taskSet) is also elided.
msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
} catch {
// Failures while reporting the oversized task are logged, not propagated.
case e: Exception => logError("Exception in error callback", e)
else {
// Reserve the cores this task will use on the target executor, then ship it.
freeCores(task.executorId) -= scheduler.CPUS_PER_TASK
executorActor(task.executorId) ! LaunchTask(new SerializableBuffer(serializedTask))
// ExecutorBackend implemented as an Akka actor: it receives LaunchTask messages and hands
// the tasks to the local executor. (Excerpt: other members and closing braces are elided.)
private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
executorId: String,
hostPort: String,
cores: Int,
sparkProperties: Seq[(String, String)])
extends Actor with ActorLogReceive with ExecutorBackend with Logging {
override def receiveWithLogging = {
case LaunchTask(data) =>
// `executor` is not set in this excerpt; a task arriving while it is still null is
// dropped with an error log.
if (executor == null) {
logError("Received LaunchTask command but executor was null")
} else {
// Decode the TaskDescription that the scheduler backend serialized with the closure serializer.
val ser = SparkEnv.get.closureSerializer.newInstance()
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
// `this` is passed as the ExecutorBackend that the task runner sends status updates to.
executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask)
// Creates a TaskRunner for the serialized task and registers it in runningTasks under taskId.
// NOTE(review): the statement that actually starts the runner (TaskRunner is a Runnable) is
// elided from this excerpt. Presumably it is submitted to the executor's thread pool;
// confirm in Executor.scala.
def launchTask(
context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
val tr = new TaskRunner(context, taskId, taskName, serializedTask)
runningTasks.put(taskId, tr)
// Runnable that executes one task on the executor and reports its state (RUNNING, then
// FINISHED / FAILED / KILLED) back through execBackend.statusUpdate.
// (Excerpt: several statements and closing braces are elided.)
class TaskRunner(
execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
extends Runnable {
override def run() {
val startTime = System.currentTimeMillis()
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
// Set just before task.run; stays 0 if the task fails earlier than that.
var taskStart: Long = 0
// Cumulative GC time (ms) summed over all of the JVM's collectors, re-read on every call.
// It is JVM-wide, so the deltas below may include GC caused by other concurrently running tasks.
def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
val startGCTime = gcTime
try {
// Split off the task's file/jar dependencies, update them (updateDependencies), then
// deserialize the Task itself using the thread's context class loader.
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
if (killed) {
// Throw an exception rather than returning, because returning within a try{} block
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
throw new TaskKilledException
// Remembered so the generic failure handler below can still report the task's metrics.
attemptedTask = Some(task)
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
// NOTE(review): Task.run takes a Long attemptId, but taskId (a Long) is narrowed with toInt
// here, which would wrap for ids above Int.MaxValue.
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
// If the task has been killed, let's fail it.
if (task.killed) {
throw new TaskKilledException
// The task's return value is serialized with SparkEnv.serializer; the Task itself and the
// result wrapper below use the closure serializer (`ser`).
val resultSer = SparkEnv.get.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
val valueBytes = resultSer.serialize(value)
val afterSerialization = System.currentTimeMillis()
// Fill in timing metrics if the task has a metrics object.
for (m <- task.metrics) {
m.executorDeserializeTime = taskStart - startTime
m.executorRunTime = taskFinish - taskStart
m.jvmGCTime = gcTime - startGCTime
m.resultSerializationTime = afterSerialization - beforeSerialization
// Bundle the value bytes with accumulator updates and metrics into one result object.
val accumUpdates = Accumulators.values
val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
// directSend = sending directly back to the driver
// Results too large for one Akka message are stored as a block, and only a small
// IndirectTaskResult carrying the block id is sent in the status update.
val (serializedResult, directSend) = {
if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val blockId = TaskResultBlockId(taskId)
// NOTE(review): the head of this call is elided. It stores serializedDirectResult under
// blockId with MEMORY_AND_DISK_SER (via the BlockManager, per the log message below).
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
(ser.serialize(new IndirectTaskResult[Any](blockId)), false)
} else {
(serializedDirectResult, true)
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
if (directSend) {
logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
} else {
// NOTE(review): the `logInfo(` call head is elided here, and the message ends with a stray ")".
s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
} catch {
// Shuffle fetch failures are reported with the exception's own TaskEndReason.
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// Either exception type is reported as KILLED only when task.killed is set; otherwise
// it falls through to the generic Throwable handler below.
case _: TaskKilledException | _: InterruptedException if task.killed => {
logInfo(s"Executor killed $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
case t: Throwable => {
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
// the default uncaught exception handler, which will terminate the Executor.
logError(s"Exception in $taskName (TID $taskId)", t)
// Metrics are only available if the task got as far as `attemptedTask = Some(task)`.
val serviceTime = System.currentTimeMillis() - taskStart
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
m.executorRunTime = serviceTime
m.jvmGCTime = gcTime - startGCTime
val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
// NOTE(review): the body of this `if` is elided. Per the comment above, fatal errors are
// handed to the default uncaught-exception handler.
if (Utils.isFatalError(t)) {
} finally {
// NOTE(review): the cleanup calls described by the two comments below are elided.
// Release memory used by this thread for shuffles
// Release memory used by this thread for unrolling blocks
// Entry point called by the executor's TaskRunner (`task.run(...)`). It creates this attempt's
// TaskContext, records the host name in its metrics, remembers the running thread, and honours
// a kill that was requested before this point.
// (Excerpt: the rest of the method, presumably the runTask(context) call, is elided.)
final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
context.taskMetrics.hostname = Utils.localHostName()
// Remember the running thread, presumably so a later kill can interrupt it.
taskThread = Thread.currentThread()
// Already marked killed: propagate the kill without interrupting the current thread.
if (_killed) {
kill(interruptThread = false)
// Subclass hook that performs the task's actual work and returns its result of type T.
// It is overridden by the result-task and shuffle-map-task implementations excerpted below.
def runTask(context: TaskContext): T
// ResultTask.runTask (presumably; the enclosing class header is not in this excerpt).
// Computes this partition of the RDD and applies the job's function to it, returning U.
override def runTask(context: TaskContext): U = {
// Deserialize the RDD and the func using the broadcast variables.
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
// Record this attempt's metrics; TaskRunner later reads them back as `task.metrics`.
metrics = Some(context.taskMetrics)
try {
// rdd.iterator yields this partition's records; func turns them into the task result.
func(context, rdd.iterator(partition, context))
} finally {
// NOTE(review): the cleanup in this finally block is elided from the excerpt.
// ShuffleMapTask.runTask (presumably; the enclosing class header is not in this excerpt).
// Computes this partition of the RDD, writes it through the shuffle manager's writer for the
// dependency's shuffle, and returns the MapStatus obtained from the writer on success.
override def runTask(context: TaskContext): MapStatus = {
// Deserialize the RDD using the broadcast variable.
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
// Record this attempt's metrics; TaskRunner later reads them back as `task.metrics`.
metrics = Some(context.taskMetrics)
// Declared outside the try so the failure path can stop a writer that was already created.
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) // write the RDD's key/value output for this partition to memory or disk
return writer.stop(success = true).get
} catch {
case e: Exception =>
// Stop a writer that was already created with success = false, then rethrow.
if (writer != null) {
writer.stop(success = false)
throw e
} finally {
// NOTE(review): the cleanup in this finally block is elided from the excerpt.
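这两种task都通过调用rdd.iterator(partition, context)来完成对该partition的计算，其间会按照拓扑顺序（沿RDD的依赖链）调用各RDD的compute。不同的是，ShuffleMapTask还需要执行shuffle write，将结果写出，以供child stage读取。两者都用到的taskBinary，是DAGScheduler在submitMissingTasks中序列化后创建的广播变量：对ResultTask而言，其内容是(rdd, func)；对ShuffleMapTask而言，其内容是(rdd, ShuffleDependency)。task运行时通过taskBinary.value取出这些字节，再用closure serializer反序列化。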