/**
* 1) 从TaskRunner开始,就进入了Task运行的工作原理。然后一步步来剖析Task内部的工作原理。 ysj
* 2) TaskRunner线程的run方法:
* a) 通过网络传输,将需要的文件、资源、jar包拷贝过来
* b) 将task中的数据反序列化
* c) 执行task的run方法(进入此方法)
* 3) Task.scala 中的run方法:
* a) 创建一个TaskContext,即task的上下文,里面记录了task执行的一些全局性的数据;
* 比如:task重试了几次、task属于哪个stage、task要处理的是哪个RDD的partition等
* b) 调用runTask()方法;实际上执行的是子类(ShuffleMapTask和ResultTask)的runTask()方法。
* 一个ShuffleMapTask会将一个RDD的元素切分为多个bucket,切分依据是ShuffleDependency中指定的partitioner,默认就是HashPartitioner。
* 一个ShuffleMapTask的runTask方法返回一个MapStatus。
* @param execBackend backend through which this runner reports the task's status (RUNNING, FINISHED, FAILED, KILLED)
* @param taskDescription description of the task to run: its id, attempt number, serialized task binary and the files/jars it depends on
*/
// NOTE(review): this copy of TaskRunner appears to be missing lines. Brace-only lines and
// Scaladoc delimiters are absent throughout, and taskName below has no initializer, so it
// does not compile as-is — restore the lost lines before building.
class TaskRunner(
execBackend: ExecutorBackend,
private val taskDescription: TaskDescription)
extends Runnable {
// Id of the task being run; included in every log message and status update below.
val taskId = taskDescription.taskId
// Descriptive thread name for this task; not referenced in the code visible here —
// presumably applied to the executing thread elsewhere (confirm).
val threadName = s"Executor task launch worker for task $taskId"
// Human-readable task name interpolated into the log messages of kill() and run().
// NOTE(review): the initializer expression after `=` is missing in this copy — restore it.
private val taskName =
/** If specified, this task has been killed and this option contains the reason. */
@volatile private var reasonIfKilled: Option[String] = None
// Id of the thread executing run(); stays -1 until run() starts.
@volatile private var threadId: Long = -1
def getThreadId: Long = threadId
/** Whether this task has been finished. */
private var finished = false // every visible read/write happens inside `synchronized`
def isFinished: Boolean = synchronized { finished }
/** How much the JVM process has spent in GC when the task starts to run. */
@volatile var startGCTime: Long = _
/**
 * The task to run. This will be set in run() by deserializing the task binary coming
 * from the driver. Once it is set, it will never be changed.
 *
 * It is null until deserialization succeeds, which is why kill(), hasFetchFailure and the
 * failure handlers in run() all null-check it before use.
 */
@volatile var task: Task[Any] = _
/**
 * Asks this task to stop.
 *
 * The reason is always recorded in `reasonIfKilled`, so a kill that arrives before the task
 * has been deserialized is still honored: run() checks it right after deserialization and
 * throws a TaskKilledException. If the task object already exists and has not finished,
 * the kill is also forwarded to it.
 *
 * @param interruptThread passed through unchanged to `Task.kill` (presumably whether the
 *                        running thread should be interrupted — confirm against Task.kill)
 * @param reason          human-readable reason; logged here and recorded in reasonIfKilled
 */
def kill(interruptThread: Boolean, reason: String): Unit = {
  logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason")
  reasonIfKilled = Some(reason)
  if (task != null) {
    // `finished` is only set while holding this runner's monitor, so checking it under the
    // same lock guarantees the kill is never forwarded to a task already marked finished.
    synchronized {
      if (!finished) {
        task.kill(interruptThread, reason)
      }
    }
  }
}
/**
 * Set the finished flag to true and clear the current thread's interrupt status.
 *
 * Runs while holding this runner's monitor, the same lock kill() takes before reading
 * `finished`, and wakes every thread waiting on that monitor.
 */
private def setTaskFinishedAndClearInterruptStatus(): Unit = synchronized {
  this.finished = true
  // SPARK-14234 - Reset the interrupted status of the thread to avoid the
  // ClosedByInterruptException during execBackend.statusUpdate which causes
  // Executor to crash
  Thread.interrupted()
  // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there
  // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False)
  // is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup:
  notifyAll()
}
/**
 * Executes the task on the current thread and reports its outcome through execBackend.
 *
 * Outline of what the visible code does:
 *  1. Reports RUNNING, fetches the task's added files/jars (updateDependencies) and
 *     deserializes the task with the thread's context class loader.
 *  2. Fails fast with TaskKilledException if a kill was requested before deserialization.
 *  3. Runs the task; afterwards always releases its block locks and managed memory, and,
 *     when the task did not throw, treats leftovers as leaks (throwing SparkException when
 *     the corresponding configuration flag is set).
 *  4. Records timing/GC metrics, serializes the result and sends it back either directly or
 *     as an IndirectTaskResult (dropped entirely if it exceeds maxResultSize).
 *  5. On failure, maps the exception to a FAILED or KILLED status update; fatal errors are
 *     additionally handed to uncaughtExceptionHandler.
 *
 * NOTE(review): this copy of run() appears to be missing lines. Every brace-only line is
 * absent and several statements are cut off (flagged inline below), so it does not compile
 * as-is. Also, setTaskFinishedAndClearInterruptStatus() is never called here, although its
 * SPARK-14234 comment says clearing the interrupt status is needed before
 * execBackend.statusUpdate — presumably a call before each terminal statusUpdate was lost.
 */
override def run(): Unit = {
threadId = Thread.currentThread.getId
val threadMXBean = ManagementFactory.getThreadMXBean
// NOTE(review): taskMemoryManager is only used for cleanup in the finally block below;
// nothing in this copy hands it to the task — confirm that line was not lost.
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTime = System.currentTimeMillis()
// NOTE(review): the then-branch below is empty, so this yields () instead of a CPU time when
// CPU timing is supported; the getter (presumably threadMXBean.getCurrentThreadCpuTime) seems
// to be missing. The same applies to taskStartCpu and taskFinishCpu below.
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
} else 0L
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
var taskStartCpu: Long = 0
startGCTime = computeTotalGcTime()
try {
// Must be set before updateDependencies() is called, in case fetching dependencies
// requires access to properties contained within (e.g. for access control).
// NOTE(review): the statement the two comment lines above refer to is missing in this copy.
updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
// Deserialize the task sent by the driver, using the thread's context class loader.
task = ser.deserialize[Task[Any]](
taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
// NOTE(review): the right-hand side of this assignment is missing in this copy.
task.localProperties =
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
val killReason = reasonIfKilled
if (killReason.isDefined) {
// Throw an exception rather than returning, because returning within a try{} block
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
throw new TaskKilledException(killReason.get)
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
// Run the actual task and measure its runtime.
// Record the wall-clock start time of the task itself (deserialization is already done).
taskStart = System.currentTimeMillis()
taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
} else 0L
var threwException = true
val value = try {
// NOTE(review): the callee of this call is missing (only its named arguments remain), and
// `res` is never yielded as the value of this try in this copy.
val res =
taskAttemptId = taskId,
attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem)
threwException = false
} finally {
// Always release the task's block locks and managed memory. Leftovers are reported as
// leaks only when the task itself did not throw (threwException == false).
val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
if (freedMemory > 0 && !threwException) {
val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
throw new SparkException(errMsg)
} else {
// NOTE(review): this else-branch has no body of its own in this copy (presumably a warning
// log of errMsg was lost), so the following statements end up nested inside it.
if (releasedLocks.nonEmpty && !threwException) {
val errMsg =
s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
releasedLocks.mkString("[", ", ", "]")
// NOTE(review): empty configuration key. Unless a key literally named "" is configured this
// presumably always returns the default (false), making the throw below unreachable; the
// real key (a lock-leak counterpart of spark.unsafe.exceptionOnMemoryLeak) seems to be lost.
if (conf.getBoolean("", false)) {
throw new SparkException(errMsg)
} else {
// NOTE(review): as above, this else-branch appears to have lost its own body.
task.context.fetchFailed.foreach { fetchFailure =>
// uh-oh. it appears the user code has caught the fetch-failure without throwing any
// other exceptions. Its *possible* this is what the user meant to do (though highly
// unlikely). So we will log an error and keep going.
logError(s"TID ${taskId} completed successfully though internally it encountered " +
s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
val taskFinish = System.currentTimeMillis() // wall-clock end time of the task
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
} else 0L
// If the task has been killed, let's fail it.
// NOTE(review): the statement the comment above describes is missing in this copy.
val resultSer = env.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
val valueBytes = resultSer.serialize(value)
val afterSerialization = System.currentTimeMillis()
// Deserialization happens in two parts: first, we deserialize a Task object, which
// includes the Partition. Second, the task deserializes the RDD and function to be run.
// NOTE(review): the metric setter calls that should wrap the next two argument lines (and the
// executor-CPU-time argument line further below) are missing in this copy.
(taskStart - deserializeStartTime) + task.executorDeserializeTime)
(taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
// Subtract the deserialization time measured inside the task (task.executorDeserializeTime)
// so that it is not counted twice as run time.
task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
(taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)
// Note: accumulator updates must be collected after TaskMetrics is updated
val accumUpdates = task.collectAccumulatorUpdates()
// TODO: do not serialize value twice
val directResult = new DirectTaskResult(valueBytes, accumUpdates)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
// Decide how the result travels back: dropped when larger than maxResultSize, indirectly
// via a TaskResultBlockId when larger than maxDirectResultSize, otherwise directly.
val serializedResult: ByteBuffer = {
if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
s"dropping it.")
ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
} else if (resultSize > maxDirectResultSize) {
val blockId = TaskResultBlockId(taskId)
// NOTE(review): the call that stores the result bytes under blockId (the next line is one of
// its arguments) and the head of the log call for the line after it are missing in this copy.
new ChunkedByteBuffer(serializedDirectResult.duplicate()),
s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
// NOTE(review): this branch presumably should yield serializedDirectResult; that line is
// missing, so as written the branch evaluates to ().
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
// Checked first so that a fetch failure is reported as such even when user code wrapped
// the FetchFailedException in another (non-fatal) exception.
case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
if (!t.isInstanceOf[FetchFailedException]) {
// there was a fetch failure in the task, but some user code wrapped that exception
// and threw something else. Regardless, we treat it as a fetch failure.
val fetchFailedCls = classOf[FetchFailedException].getName
logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
s"failed, but the ${fetchFailedCls} was hidden by another " +
s"exception. Spark is handling this like a fetch failure and ignoring the " +
s"other exception: $t")
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
case t: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
// An interrupt or non-fatal error after a kill was requested on the task is reported as
// KILLED rather than FAILED.
case _: InterruptedException | NonFatal(_) if
task != null && task.reasonIfKilled.isDefined =>
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
// NOTE(review): the head of this call (presumably execBackend.statusUpdate) is missing;
// only its arguments remain.
taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskFailedReason
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// Deliberately catches every Throwable: even fatal errors are first reported to the driver
// as FAILED, then escalated to uncaughtExceptionHandler below.
case t: Throwable =>
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
// the default uncaught exception handler, which will terminate the Executor.
logError(s"Exception in $taskName (TID $taskId)", t)
// Collect latest accumulator values to report back to the driver
val accums: Seq[AccumulatorV2[_, _]] =
if (task != null) {
task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.collectAccumulatorUpdates(taskFailed = true)
} else {
// NOTE(review): the value of this else-branch (presumably an empty Seq) is missing, and the
// receiver of the mapping on the next line (presumably `accums.map(acc`) has been lost.
val accUpdates = => acc.toInfo(Some(acc.value), None))
val serializedTaskEndReason = {
try {
ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
} finally {
// NOTE(review): the body of this finally block is missing in this copy.
/**
 * Whether the task recorded a fetch failure in its TaskContext.
 *
 * Null-safe: false while the task has not been deserialized yet (task == null) or has no
 * context. run()'s catch block uses it to report a FAILED status with the fetch-failure
 * reason even when user code wrapped the FetchFailedException in another exception.
 */
private def hasFetchFailure: Boolean = {
task != null && task.context != null && task.context.fetchFailed.isDefined