Spark源码阅读(三) --- Spark的任务调度机制

目录

一、任务调度机制的概述

从层级上来看

二、任务调度机制的源码

1、Application -> Job

2、Job -> Stage

1、withScope方法

2、clean方法

3、runJob

4、submitJob方法与event的处理

5、handleJobSubmitted()方法与stage的划分

3、Stage -> Task

1、sumbitStage()方法

2、taskScheduler.submitTasks()方法

3、SchedulerBackend的start与reviveOffers()方法

4、makeOffer()方法与资源分配

后话


一、任务调度机制的概述

当一个Spark Application被提交之后,Spark启动driver,driver端主要承担着解析代码的工作,并对Application中的运行逻辑进行划分

从层级上来看

Application -> Job:Application可以以Action算子划分为多个Job

Job -> Stage:Job会进入到DAGScheduler中,DAGScheduler会构建为基于Job的DAG(逻辑上的),并将DAG根据宽依赖和窄依赖划分为多个Stage

Stage -> Task:Stage有多个Task组成,它会被封装到一个TaskSet对象内,TaskSet会进入到TaskScheduler中,并根据partition数来划分出一个个Task,一般一个Task对应一个partition,最终Task会按照指定的调度被分发到已经启动好的Executor去执行

整体调度流程如下图所示

Spark任务调度概览

二、任务调度机制的源码

1、Application -> Job

在本文第一部分我们说过一个Application根据Action算子来划分为一个个Job

以如下代码为例

object ApplicationSubmit {
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder().master("local").getOrCreate()
    val rdd = sparkSession.sparkContext.makeRDD(Seq(1,2,3,4,5,6))
    rdd.foreach(print)
    println(rdd.sum())
  }
}

运行这段code,我们可以查看console输出的日志

注意图中红框圈出的文字,我们可以发现SparkContext先运行了Job 0:foreach(),直至Job 0结束,才运行Job 1:Sum()

sum和foreach都是Action算子,通过这个log信息我们进一步验证了一个Application以Action算子为界进行对Job的划分

2、Job -> Stage

1、withScope方法

根据上面的代码块,我们选取其中一个Action算子进行深入研究,以foreach算子为例

查看foreach()的source code,是如下的代码

/**
   * Applies a function f to all elements of this RDD.
   */
  def foreach(f: T => Unit): Unit = withScope {
    val cleanF = sc.clean(f)
    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
  }

可以看到这是个高阶函数,并且它方法的主体最外层套着一个withScope方法

withScope方法是RDDOperationScope中的一个方法,它的主要目的是为了做DAG可视化的,我们可以在Spark UI中查看到一个个stage可视化的执行流程

我们可以看到foreach方法方法主体里面有一个clean方法和runJob方法

2、clean方法

clean方法在spark源码中大量的出现,点击进去查看它的源码

/**
   * Clean a closure to make it ready to be serialized and sent to tasks
   * (removes unreferenced variables in $outer's, updates REPL variables)
   * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
   * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
   * if not.
   *
   * @param f the closure to clean
   * @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
   * @throws SparkException if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
   *   serializable
   * @return the cleaned closure
   */
  private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
    ClosureCleaner.clean(f, checkSerializable)
    f
  }

发现底层调用的是ClosureCleaner.clean()方法,并且官方注释是这么说的

Clean a closure to make it ready to be serialized and sent to tasks

翻译过来就是

清除闭包以使该闭包能够序列化并发送到各task

我们继续按深入去查看clean方法

/**
   * Helper method to clean the given closure in place.
   *
   * The mechanism is to traverse the hierarchy of enclosing closures and null out any
   * references along the way that are not actually used by the starting closure, but are
   * nevertheless included in the compiled anonymous classes. Note that it is unsafe to
   * simply mutate the enclosing closures in place, as other code paths may depend on them.
   * Instead, we clone each enclosing closure and set the parent pointers accordingly.
   *
   * .......该注释过长这里只展示一部分
   */
private def clean(
      func: AnyRef,
      checkSerializable: Boolean,
      cleanTransitively: Boolean,
      accessedFields: Map[Class[_], Set[String]]): Unit = {

    // most likely to be the case with 2.12, 2.13
    // so we check first
    // non LMF-closures should be less frequent from now on
    val lambdaFunc = getSerializedLambda(func)

    if (!isClosure(func.getClass) && lambdaFunc.isEmpty) {
      logDebug(s"Expected a closure; got ${func.getClass.getName}")
      return
    }
    //......
    //......
    //......该方法过长这里只展示一部分
    if (checkSerializable) {
      ensureSerializable(func)
    }
}

最终停留在ClosureCleaner的这个clean方法

首先看注释,注释的意思是说清理用不到引用,并且确保该闭包一定能够序列化

查看该方法体,内部通过反射获取了我们传递进来的func内部的使用的class、field等等,并一一进行check,如果没有任何问题,将会在该方法的末尾返回ensureSerializable(func),并确定该func可以正常执行。

那为什么要进行这个操作呢?因为spark是支持分布式运行的,那么有分布式必然会涉及到网络IO,那么有IO必然需要序列化,所以需要进行这样的check操作。

举个例子

object ApplicationSubmit {
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder().master("local").getOrCreate()
    val rdd = sparkSession.sparkContext.makeRDD(Seq(new TestClass("1"),new TestClass("2")))
    rdd.foreach(println)
  }

  class TestClass(name: String) {
    val className: String = name
  }

}

这段代码里创建了一各包含未序列化Class的RDD,运行后果然报出了一堆报错信息

Exception in thread "main" org.apache.spark.SparkException: Job aborted due to stage failure: Failed to serialize task 0, not attempting to retry it. Exception during serialization: java.io.NotSerializableException: sourceCodeParse.ApplicationSubmit$TestClass
Serialization stack:
	- object not serializable (class: sourceCodeParse.ApplicationSubmit$TestClass, value: sourceCodeParse.ApplicationSubmit$TestClass@21dc04fd)
	- element of array (index: 0)
	- array (class [LsourceCodeParse.ApplicationSubmit$TestClass;, size 2)
	- field (class: scala.collection.mutable.WrappedArray$ofRef, name: array, type: class [Ljava.lang.Object;)
	- object (class scala.collection.mutable.WrappedArray$ofRef, WrappedArray(sourceCodeParse.ApplicationSubmit$TestClass@21dc04fd, sourceCodeParse.ApplicationSubmit$TestClass@7b65a1cd))
	- writeObject data (class: org.apache.spark.rdd.ParallelCollectionPartition)
	- object (class org.apache.spark.rdd.ParallelCollectionPartition, org.apache.spark.rdd.ParallelCollectionPartition@691)
	- field (class: org.apache.spark.scheduler.ResultTask, name: partition, type: interface org.apache.spark.Partition)
	- object (class org.apache.spark.scheduler.ResultTask, ResultTask(0, 0))

我们将其改成case class或者令其extends Serializable

object ApplicationSubmit {
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder().master("local").getOrCreate()
    val rdd = sparkSession.sparkContext.makeRDD(Seq(new TestClass("1"),new TestClass("2")))
    rdd.foreach(print)
  }

  case class TestClass(name: String) {
    val className: String = name
  }

}

结果成功运行了

3、runJob

clean()方法运行过后,runJob()将被调用,在spark源码中,有许多runJob()方法,它们最后都会调用如下方法

 /**
   * Run a function on a given set of partitions in an RDD and pass the results to the given
   * handler function. This is the main entry point for all actions in Spark.
   *
   * @param rdd target RDD to run tasks on
   * @param func a function to run on each partition of the RDD
   * @param partitions set of partitions to run on; some jobs may not want to compute on all
   * partitions of the target RDD, e.g. for operations like `first()`
   * @param resultHandler callback to pass each result to
   */
  def runJob[T, U: ClassTag](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      resultHandler: (Int, U) => Unit): Unit = {
    if (stopped.get()) {
      throw new IllegalStateException("SparkContext has been shutdown")
    }
    val callSite = getCallSite
    val cleanedFunc = clean(func)
    logInfo("Starting job: " + callSite.shortForm)
    if (conf.getBoolean("spark.logLineage", false)) {
      logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
    }
    dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
    progressBar.foreach(_.finishAll())
    rdd.doCheckpoint()
  }

首先是getCallSite,callSite在源码里是这样描述的

CallSite represents a place in user code. It can have a short and a long form.

翻译过来就是

CallSite代表了用户代码里的一个位置,它有长格式和短格式

这时我们查看getCallSite的源码,其方法大致的功能是取当前线程的堆栈信息,将符合规则的放入栈顶,并最终返回,其目的就是为了打印出log信息,便于查看job的进度。

接下来是

dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)

看到dagScheduler,Job要开始被dagSchduler调用划分为一个个stage了。

注意dagScheduler在sparkContext创建时就创建了,其连续调用的构造方法分别是

_dagScheduler = new DAGScheduler(this)
def this(sc: SparkContext) = this(sc, sc.taskScheduler)
def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
    this(
      sc,
      taskScheduler,
      sc.listenerBus,
      sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
      sc.env.blockManager.master,
      sc.env)
  }

可以注意到taskScheduler也被一并传进来了,代表taskScheduler也被创建了,并且先于dagScheduler创建好,其构造方法是

val (sched, ts) = SparkContext.createTaskScheduler(this, master, deployMode)
_schedulerBackend = sched
_taskScheduler = ts

schedulerBackend是一个trait,它有CoarseGrainedSchedulerBackend,StandaloneSchedulerBackend和LocalSchedulerBackend三种不同的实现,它们被sparkContext用于管理不同deploy mode下的资源。

4、submitJob方法与event的处理

回到dagScheduler.runJob()方法,我们发现内部又调用了submitJob()的这样一个方法

val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)

其内部源码是

def submitJob[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: CallSite,
      resultHandler: (Int, U) => Unit,
      properties: Properties): JobWaiter[U] = {
    // Check to make sure we are not launching a task on a partition that does not exist.
    val maxPartitions = rdd.partitions.length
    partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
      throw new IllegalArgumentException(
        "Attempting to access a non-existent partition: " + p + ". " +
          "Total number of partitions: " + maxPartitions)
    }

    val jobId = nextJobId.getAndIncrement()
    if (partitions.size == 0) {
      // Return immediately if the job is running 0 tasks
      return new JobWaiter[U](this, jobId, 0, resultHandler)
    }

    assert(partitions.size > 0)
    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
    val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
    eventProcessLoop.post(JobSubmitted(
      jobId, rdd, func2, partitions.toArray, callSite, waiter,
      SerializationUtils.clone(properties)))
    waiter
  }

可以看到其内部先生成了一个JobId,随后又创建了一个JobWaiter去监听Job的执行状态,其标记状态的有成功和失败,只有当所有task都成功其状态才会标记为成功,否则将标记为失败

接下来是eventProcessLoop类,它实际上是一个DAGSchedulerEventProcessLoop类,DAGSchedulerEventProcessLoop继承了EventLoop类,在其内部对onReceive方法进行了重写,用以调用doOnReceive方法,以下是doOnReceive的源码

private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
    case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
      dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)
    //......省略大部分

    case ResubmitFailedStages =>
      dagScheduler.resubmitFailedStages()
  }

可以看到其内部使用了case进行event匹配,最终调用了dagScheduler.handleJobSubmitted这个方法进行Job的处理

那么onReceive方法又是如何触发的呢,回到eventProcessLoop处,其调用了post()方法,点进去post()方法内部

/**
   * Put the event into the event queue. The event thread will process it later.
   */
  def post(event: E): Unit = {
    eventQueue.put(event)
  }

eventQueue是一个LinkedBlockingDeque,将我们传入的event放入到这个事件队列里,而post方法本身被EventLoop所持有,EventLoop内部有一个eventThread,是一个Thread类,其持有一个run()方法,run()方法内部为

override def run(): Unit = {
      try {
        while (!stopped.get) {
          val event = eventQueue.take()
          try {
            onReceive(event)
          } catch {
            case NonFatal(e) =>
              try {
                onError(e)
              } catch {
                case NonFatal(e) => logError("Unexpected error in " + name, e)
              }
          }
        }
      } catch {
        case ie: InterruptedException => // exit even if eventQueue is not empty
        case NonFatal(e) => logError("Unexpected error in " + name, e)
      }
    }

我们可以看到eventQueue取出了其内部所存在的event,并使用onReceive方法调用,而DAGSchedulerEventProcessLoop就是eventProcessLoop类,其本身继承了EventLoop类,而eventProcessLoop类会在DagScheduler类体中的结尾处调用eventProcessLoop.start()方法(本质上就是个DagScheduler构造函数内部调用,所以eventThread在DagScheduler创建时就启动了)

start()方法会调用eventThread中的start()方法,其内部调用run()方法,run()方法内部调用抽像方法onReceive(),进而反复的处理submit上去的event(JobSubmit事件)

5、handleJobSubmitted()方法与stage的划分

接下来我们回到doOnReceive()方法处,我们刚刚提交的是JobSubmitted类,则将触发这段代码

case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
      dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)

进入handleJobSubmitted()这个方法

private[scheduler] def handleJobSubmitted(jobId: Int,
      finalRDD: RDD[_],
      func: (TaskContext, Iterator[_]) => _,
      partitions: Array[Int],
      callSite: CallSite,
      listener: JobListener,
      properties: Properties) {
    var finalStage: ResultStage = null
    try {
      // New stage creation may throw an exception if, for example, jobs are run on a
      // HadoopRDD whose underlying HDFS files have been deleted.
      finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
    } catch {
      case e: BarrierJobSlotsNumberCheckFailed =>
        logWarning(s"The job $jobId requires to run a barrier stage that requires more slots " +
          "than the total number of slots in the cluster currently.")
        // If jobId doesn't exist in the map, Scala coverts its value null to 0: Int automatically.
        val numCheckFailures = barrierJobIdToNumTasksCheckFailures.compute(jobId,
          new BiFunction[Int, Int, Int] {
            override def apply(key: Int, value: Int): Int = value + 1
          })
        if (numCheckFailures <= maxFailureNumTasksCheck) {
          messageScheduler.schedule(
            new Runnable {
              override def run(): Unit = eventProcessLoop.post(JobSubmitted(jobId, finalRDD, func,
                partitions, callSite, listener, properties))
            },
            timeIntervalNumTasksCheck,
            TimeUnit.SECONDS
          )
          return
        } else {
          // Job failed, clear internal data.
          barrierJobIdToNumTasksCheckFailures.remove(jobId)
          listener.jobFailed(e)
          return
        }

      case e: Exception =>
        logWarning("Creating new stage failed due to exception - job: " + jobId, e)
        listener.jobFailed(e)
        return
    }
    // Job submitted, clear internal data.
    barrierJobIdToNumTasksCheckFailures.remove(jobId)

    val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
    clearCacheLocs()
    logInfo("Got job %s (%s) with %d output partitions".format(
      job.jobId, callSite.shortForm, partitions.length))
    logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")")
    logInfo("Parents of final stage: " + finalStage.parents)
    logInfo("Missing parents: " + getMissingParentStages(finalStage))

    val jobSubmissionTime = clock.getTimeMillis()
    jobIdToActiveJob(jobId) = job
    activeJobs += job
    finalStage.setActiveJob(job)
    val stageIds = jobIdToStageIds(jobId).toArray
    val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
    listenerBus.post(
      SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
    submitStage(finalStage)
  }

一个job实际上就只能被分为ResultStage和ShuffleMapStage两种类型。ResultStage是执行一个action操作的最后的一个stage,一般用来做一些生成结果的操作,如输出结果,写文件之类的。而ShuffleMapStage会生成用于shuffle的数据,即中间数据。可以看到一个finalStage通过createResultStage()方法被创建,查看该方法源码如下

/**
   * Create a ResultStage associated with the provided jobId.
   */
  private def createResultStage(
      rdd: RDD[_],
      func: (TaskContext, Iterator[_]) => _,
      partitions: Array[Int],
      jobId: Int,
      callSite: CallSite): ResultStage = {
    checkBarrierStageWithDynamicAllocation(rdd)
    checkBarrierStageWithNumSlots(rdd)
    checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size)
    val parents = getOrCreateParentStages(rdd, jobId)
    val id = nextStageId.getAndIncrement()
    val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
    stageIdToStage(id) = stage
    updateJobIdStageIdMaps(jobId, stage)
    stage
  }

看到内部先调用了getOrCreateParentStages()方法获取了parents,进入该方法发现底层调用了这样一段代码,进一步证实了stage确实根据shuffle操作进行划分的

 getShuffleDependencies(rdd).map { shuffleDep =>
      getOrCreateShuffleMapStage(shuffleDep, firstJobId)
    }.toList

getOrCreateParentStages()其实是划分job为一个个stage的核心方法,其处理同一个stage内rdd的顺序是完全随机,内部调用了getOrCreateShuffleMapStage()方法,该方法主要部分如下


    shuffleIdToMapStage.get(shuffleDep.shuffleId) match {
      case Some(stage) =>
        stage

      case None =>
        getMissingAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
          if (!shuffleIdToMapStage.contains(dep.shuffleId)) {
            createShuffleMapStage(dep, firstJobId)
          }
        }
        createShuffleMapStage(shuffleDep, firstJobId)
    }
  

可以看到内部通过调用createShuffleMapStage()方法根据一个个shuffle创建并返回shuffleMapStage类,所以说实质上stage create方法的实际调用完成顺序是这样的

createShuffleMapStage(shuffleDep, firstJobId)
getOrCreateShuffleMapStage(shuffleDep, firstJobId)
getOrCreateParentStages(rdd, jobId)
createResultStage(finalRDD, func, partitions, jobId, callSite)

在调用createResultStage()方法时就已经将他之前依赖的所有parent Stage创建完成了,包括我们在createShuffleMapStage()方法内也可以看到对parents的获取

def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = {
    val rdd = shuffleDep.rdd
    checkBarrierStageWithDynamicAllocation(rdd)
    checkBarrierStageWithNumSlots(rdd)
    checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions)
    val numTasks = rdd.partitions.length
    val parents = getOrCreateParentStages(rdd, jobId)
    val id = nextStageId.getAndIncrement()
    //......
}

在各stage划分完毕之后,回到handleJobSubmitted()方法,可以看到一个ActiveJob类被创建

val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)

到底这里一个Job才算是真正被生成了,注意在某些情况下Job会本地运行

(1)spark.localExecution.enabled设置为true

(2)用户指定本地运行

(3)finalStage没有parent Stage

(4)仅有一个partition

在(3)、(4)的情况下本地运行是为了保证任务的快速执行

listnerBus是一个trait,可以接收事件并将事件提交给对应的事件监听器

listenerBus.post(
      SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))

3、Stage -> Task

1、sumbitStage()方法

随后触发了submitStage()方法对stage进行提交

submitStage(finalStage)

点击进入submitStage内,源码如下

/** Submits stage, but first recursively submits any missing parents. */
  private def submitStage(stage: Stage) {
    val jobId = activeJobForStage(stage)
    if (jobId.isDefined) {
      logDebug("submitStage(" + stage + ")")
      if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
        val missing = getMissingParentStages(stage).sortBy(_.id)
        logDebug("missing: " + missing)
        if (missing.isEmpty) {
          logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
          submitMissingTasks(stage, jobId.get)
        } else {
          for (parent <- missing) {
            submitStage(parent)
          }
          waitingStages += stage
        }
      }
    } else {
      abortStage(stage, "No active job for stage " + stage.id, None)
    }
  }

内部调用了getMissingParentStages()方法和递归调用了submitStage()方法,其主要目的是递归调用那些未提交的parent stage,只有所有parent stage调用并计算完毕后,才轮到该stage提交并计算。

当missing这个List为空时,代表没有未提交和计算的parent stage了,可以执行submitMissingTasks提交当前得stage了

在submitMissingTasks内部中,获取了需要计算的当前stage的分区partitionsToCompute,并将该stage加入到runningStages的HashSet中,随后是如下代码

stage match {
      case s: ShuffleMapStage =>
        outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1)
      case s: ResultStage =>
        outputCommitCoordinator.stageStart(
          stage = s.id, maxPartitionId = s.rdd.partitions.length - 1)
    }

outPutCommitCoordinator实际上是一个输出提交协调器,source code中注释对它的描述是

/**
 * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
 */

有权决定能否将tasks输出提交到HDFS。使用“第一个提交者获胜”的策略(FIFO)

outPutCommitCoordinator在Driver端和Executors端都是有实例的,同时在driver端上还注册了一个OutputCommitCoordinatorEndpoint,该class继承了RpcEndpoint,Executors上的OutputCommitCoordinator都会通过OutputCommitCoordinatorEndpoint的RpcEndpointRefDriver上的OutputCommitCoordinator通信,并向其询问是否能够将输出提交到HDFS

val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
      stage match {
        case s: ShuffleMapStage =>
          partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
        case s: ResultStage =>
          partitionsToCompute.map { id =>
            val p = s.partitions(id)
            (id, getPreferredLocs(stage.rdd, p))
          }.toMap
      }
    } catch {
      case NonFatal(e) =>
        stage.makeNewStageAttempt(partitionsToCompute.size)
        listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
        abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
        runningStages -= stage
        return
    }

getPreferredLocs是用来返回RDD分区中的位置信息的,里面调用了getPreferredLocsInternal方法,getPreferredLocsInternal方法体如下

注:

Spark 有五种不同的 Locality Level,分别是
PROCESS_LOCAL:数据和task在同一个Executor中,中间无需任何数据传输,效率最高,如cache后的数据一般就是这种Locality level。
NODE_LOCAL:数据和task在同一个worker上,但是他们在不同的Executor,或者在HDFS上,这个效率稍慢一些,需要从文件读或者进程间传输。
NO_PREF: 数据从哪里访问都一样快,不需要位置优先,一般是读取数据库里的数据会产生这 种Locality Level。
RACK_LOCAL: 数据在同一机架的不同节点上,需要通过网络传输数据及文件 IO。
ANY: 数据在非同一机架的网络上,速度最慢,出现跨集群传输数据才会有这种情况,一般不会出现。

该方法内会依次

private def getPreferredLocsInternal(
      rdd: RDD[_],
      partition: Int,
      visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = {
    // If the partition has already been visited, no need to re-visit.
    // This avoids exponential path exploration.  SPARK-695
    if (!visited.add((rdd, partition))) {
      // Nil has already been returned for previously visited partitions.
      return Nil
    }
    // If the partition is cached, return the cache locations
    val cached = getCacheLocs(rdd)(partition)
    if (cached.nonEmpty) {
      return cached
    }
    // If the RDD has some placement preferences (as is the case for input RDDs), get those
    val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
    if (rddPrefs.nonEmpty) {
      return rddPrefs.map(TaskLocation(_))
    }

    // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency
    // that has any placement preferences. Ideally we would choose based on transfer sizes,
    // but this will do for now.
    rdd.dependencies.foreach {
      case n: NarrowDependency[_] =>
        for (inPart <- n.getParents(partition)) {
          val locs = getPreferredLocsInternal(n.rdd, inPart, visited)
          if (locs != Nil) {
            return locs
          }
        }

      case _ =>
    }

    Nil
  }

判断该rdd的partition是否被访问,是否被cache,并最终返回其一个最佳的locality information。注意到里面调用了TaskLocation,TaskLocation是一个sealed trait,有三个子类实现了它,分别是

case class ExecutorCacheTaskLocation(override val host: String, executorId: String)
  extends TaskLocation {
  override def toString: String = s"${TaskLocation.executorLocationTag}${host}_$executorId"
}

/**
 * A location on a host.
 */
private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation {
  override def toString: String = host
}

/**
 * A location on a host that is cached by HDFS.
 */
private [spark] case class HDFSCacheTaskLocation(override val host: String) extends TaskLocation {
  override def toString: String = TaskLocation.inMemoryLocationTag + host
}

ExecutorCacheTaskLocation代表数据已经被cache了,正是PROCESS_LOCAL情况,HostTaskLocation和HDFSCacheTaskLocation分别代表数据存储在同worker上和存储在HDFS上,正是NODE_LOCAL的情况。task的locality level与partition是一致的。

接下来会通过根据不同的task类型,通过broadCast分发不同的数据到executor上

For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).

对于ShuffleMapTask,它序列化和分发rdd以及shuffleDep

For ResultTask, serialize and broadcast (rdd, func).

对于ResultTask,它序列化和分发rdd以及func

RDDCheckpointData.synchronized {
        taskBinaryBytes = stage match {
          case stage: ShuffleMapStage =>
            JavaUtils.bufferToArray(
              closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
          case stage: ResultStage =>
            JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
        }

        partitions = stage.rdd.partitions
      }

      taskBinary = sc.broadcast(taskBinaryBytes)

之后将根据不同的stage类型生成不同的task,其对应的源码如下

val tasks: Seq[Task[_]] = try {
      val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
      stage match {
        case stage: ShuffleMapStage =>
          stage.pendingPartitions.clear()
          partitionsToCompute.map { id =>
            val locs = taskIdToLocations(id)
            val part = partitions(id)
            stage.pendingPartitions += id
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
              taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
              Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
          }

        case stage: ResultStage =>
          partitionsToCompute.map { id =>
            val p: Int = stage.partitions(id)
            val part = partitions(p)
            val locs = taskIdToLocations(id)
            new ResultTask(stage.id, stage.latestInfo.attemptNumber,
              taskBinary, part, locs, id, properties, serializedTaskMetrics,
              Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
              stage.rdd.isBarrier())
          }
      }
    } catch {
      case NonFatal(e) =>
        abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
        runningStages -= stage
        return
    }

最后会执行taskScheduler.submitTasks(new TaskSet( tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties))方法,stage被封装成一个个TaskSet通过submitTasks提交,至此DagScheduler的工作全部完成,进入taskScheduler的工作。

2、taskScheduler.submitTasks()方法

之前提到过taskScheduler是一个trait,其实现类是TaskSchedulerImpl,该类在sparkContext创建时通过createTaskScheduler方法创建好了,点进去查看可以看到有着数量众多的

val scheduler = new TaskSchedulerImpl()

证明askScheduler的实例类确实是TaskSchedulerImpl,回到submitTasks方法,源码如下

 override def submitTasks(taskSet: TaskSet) {
    val tasks = taskSet.tasks
    logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
    this.synchronized {
      val manager = createTaskSetManager(taskSet, maxTaskFailures)
      val stage = taskSet.stageId
      val stageTaskSets =
        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])

      stageTaskSets.foreach { case (_, ts) =>
        ts.isZombie = true
      }
      stageTaskSets(taskSet.stageAttemptId) = manager
      schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

      if (!isLocal && !hasReceivedTask) {
        starvationTimer.scheduleAtFixedRate(new TimerTask() {
          override def run() {
            if (!hasLaunchedTask) {
              logWarning("Initial job has not accepted any resources; " +
                "check your cluster UI to ensure that workers are registered " +
                "and have sufficient resources")
            } else {
              this.cancel()
            }
          }
        }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
      }
      hasReceivedTask = true
    }
    backend.reviveOffers()
  }

首先创建一个TaskSetManager,TaskSetManager会根据数据的locality-aware(其locality-aware就是之前partition所获取prefered locality信息)为Task分配计算资源,监控Task的执行状态,并在失败时采取重试策略,或者推迟执行。该类主要的接口是resourceOffer,该方法会询问TaskSet,否想要在一个节点上运行一个task,并在task状态变更时通知TaskSet。

taskSetsByStageIdAndAttempt底层是一个HashMap,它会先通过stageId获取一个名为stageTaskSets的HashMap<Int, TaskSetManager>,stageAttemptId则是自动生成的一个stage尝试id,个人理解这里是因为关于该stage前面的尝试都有问题,所以要将前面所有的TaskSetManager的状态均设为zombie,随后在该段代码内

 stageTaskSets(taskSet.stageAttemptId) = manager

重新添加新的尝试id stageAttemptId和该stage的TaskSetManager。

随后将TaskSetManager添加到schedulableBuilder内,这里SchedulableBuilder实质上是一个trait,他有两个实现类,分别是FIFOSchedulableBuilder(先进先出)和FairSchedulableBuilder(公平调度),默认是FIFO,可以通过spark.scheduler.mode模式进行设置。该类是在SparkContext执行TaskSchedulerImpl的initialize()方法进行创建的。其source code如下

def initialize(backend: SchedulerBackend) {
    this.backend = backend
    schedulableBuilder = {
      schedulingMode match {
        case SchedulingMode.FIFO =>
          new FIFOSchedulableBuilder(rootPool)
        case SchedulingMode.FAIR =>
          new FairSchedulableBuilder(rootPool, conf)
        case _ =>
          throw new IllegalArgumentException(s"Unsupported $SCHEDULER_MODE_PROPERTY: " +
          s"$schedulingMode")
      }
    }
    schedulableBuilder.buildPools()
  }

starvationTimer.scheduleAtFixedRate是一个防止task没资源进入长期饥饿的定时器,该task提交后这个定时器就会被cancel掉。

最后调用了backend.reviveOffers()方法。

3、SchedulerBackend的start与reviveOffers()方法

我们知道SchedulerBackend是一个trait,有三种实现类,用以管理不同deploy mode下的资源的,并且在SparkContext初始化时就已经创建好。

则我们以CoarseGrainedSchedulerBackend的reviveOffers为例,看到其override的方法为

override def reviveOffers() {
    driverEndpoint.send(ReviveOffers)
  }

可以看到里面是调用了driverEndpoint的send()方法,driverEndpoint实际上是一个RpcEndpointRef类,该类只有一个实现类就是NettyRpcEndpointRef,主要功能是用来通信的,用于driver端发送信息为各TaskSet分配执行所需的资源。

接下来从send()方法深入进去,发现什么都没有了,而且我们也并不知道diverEndpoint实在何时创建的,并且这个send的发信对象是谁也不清楚。那么TaskScheduler又是如何提交Task的呢?
回到TaskSchedulerImpl class处,会发现上面的注释中有这么一句话

/**
 * Clients should first call initialize() and start(), then submit task sets through the
 * submitTasks method.
 */

大概意思就是initialize() and start()会被先调用,随后才会调用submitTasks提交task,initialize方法我们知道在sparkContext创建的时候就被调用用来创建TaskScheduler了,那么start()方法其实就是在随后被调用的,我们可以看到sparkContext内有这么一段代码

// Create and start the scheduler
    val (sched, ts) = SparkContext.createTaskScheduler(this, master, deployMode)
    _schedulerBackend = sched
    _taskScheduler = ts
    _dagScheduler = new DAGScheduler(this)
    _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet)

    _taskScheduler.start()

在TaskSchedulerImpl内override的start方法为

override def start() {
    backend.start()

    if (!isLocal && conf.getBoolean("spark.speculation", false)) {
      logInfo("Starting speculative execution thread")
      speculationScheduler.scheduleWithFixedDelay(new Runnable {
        override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
          checkSpeculatableTasks()
        }
      }, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS)
    }
  }

可以看到内部调用了backend的start()方法,我们之前说过SchedulerBackend一个trait,内部有三个实现类,用以管理不同deploy mode下的资源调度,进入CoarseGrainedSchedulerBackend的start方法内部,代码如下

verride def start() {
    val properties = new ArrayBuffer[(String, String)]
    for ((key, value) <- scheduler.sc.conf.getAll) {
      if (key.startsWith("spark.")) {
        properties += ((key, value))
      }
    }

    // TODO (prashant) send conf instead of properties
    driverEndpoint = createDriverEndpointRef(properties)
  }

可以看到再获取完spark相关的conf之后,driverEndpoint通过createDriverEndpointRef方法创建,

createDriverEndpointRef方法内部为

protected def createDriverEndpointRef(
      properties: ArrayBuffer[(String, String)]): RpcEndpointRef = {
    rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties))
  }

可以看到里面有这样一段代码,rpcEnv.setupEndpoint((ENDPOINT_NAME, createDriverEndpoint(properties)),而createDriverEndpoint(properties)方法内部为

protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
    new DriverEndpoint(rpcEnv, properties)
  }

很明显创建出了一个DriverEndpoint类,并且该类被设为rpcEnv的Endpoint,回到banked.reviveOffers方法,我们终于明白了这个send方法的发信对象正是被设为Endpoint的DriverEndpoint,并且发送的是一个ReviveOffers message

总结下来就是CoarseGrainedSchedulerBackend所持有的driverEndpoint本质是一个RpcEndpointRef,其底层只有一个继承实现类就是NettyRpcEndpointRef,它在初始化的时候设置好了其发信的接收节点driverEndpoint,用以给driverEndpoint发信。

那么driverEndpoint底层又是什么呢,我们在CoarseGrainedSchedulerBackend中找到定义好的driverEndpoint

class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
    extends ThreadSafeRpcEndpoint with Logging {
/*..*/
}

可以看到其继承了ThreadSafeRpcEndpoint  trait,它又继承了RpcEndpoint trait,这一个线程安全的RpcEndpoint。我们可以看到RpcEndpoint trait内有许多方法

private[spark] trait RpcEndpoint {

  /**
   * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to.
   */
  val rpcEnv: RpcEnv

  /**
   * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is
   * called. And `self` will become `null` when `onStop` is called.
   *
   * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not
   * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called.
   */
  final def self: RpcEndpointRef = {
    require(rpcEnv != null, "rpcEnv has not been initialized")
    rpcEnv.endpointRef(this)
  }

  /**
   * Process messages from `RpcEndpointRef.send` or `RpcCallContext.reply`. If receiving a
   * unmatched message, `SparkException` will be thrown and sent to `onError`.
   */
  def receive: PartialFunction[Any, Unit] = {
    case _ => throw new SparkException(self + " does not implement 'receive'")
  }

  /**
   * Process messages from `RpcEndpointRef.ask`. If receiving a unmatched message,
   * `SparkException` will be thrown and sent to `onError`.
   */
  def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
  }

  /**
   * Invoked when any exception is thrown during handling messages.
   */
  def onError(cause: Throwable): Unit = {
    // By default, throw e and let RpcEnv handle it
    throw cause
  }

  /**
   * Invoked when `remoteAddress` is connected to the current node.
   */
  def onConnected(remoteAddress: RpcAddress): Unit = {
    // By default, do nothing.
  }

  /**
   * Invoked when `remoteAddress` is lost.
   */
  def onDisconnected(remoteAddress: RpcAddress): Unit = {
    // By default, do nothing.
  }

  /**
   * Invoked when some network error happens in the connection between the current node and
   * `remoteAddress`.
   */
  def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
    // By default, do nothing.
  }

  /**
   * Invoked before [[RpcEndpoint]] starts to handle any message.
   */
  def onStart(): Unit = {
    // By default, do nothing.
  }

  /**
   * Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot
   * use it to send or ask messages.
   */
  def onStop(): Unit = {
    // By default, do nothing.
  }

  /**
   * A convenient method to stop [[RpcEndpoint]].
   */
  final def stop(): Unit = {
    val _self = self
    if (_self != null) {
      rpcEnv.stop(_self)
    }
  }
}

其中onStart()方法指将在处理message前invoke,接下来则是receive和receiveAndReply方法,它们在收到RpcEndpointRef的send()方法或ask()方法时分别触发,最后则是onStop()方法结束它的生命周期。所以其调用的顺序是

driverEndpoint(RpcEndpointRef):Banked.start -> createDriverEndpointRef(创建)-> rpcEnv.setupEndpoint(设置接收message得Endpoint)

DriverEndpoint(DriverEndpoint):Banked.start -> 通过createDriverEndpointRef内createDriverEndpoint创建 -> message来时onStart -> receive或receiveAndReply -> onStop()

那么了解了他们的调用顺序,回到reviveOffers()方法,其driverEndpoint向DriverEndpoint send的是ReviveOffers信息,那我们找到DriverEndpoint内的receive(),可以看到这段代码

 case ReviveOffers =>
        makeOffers()

4、makeOffer()方法与资源分配

点进makeOffer()方法,可以看到如下代码

private def makeOffers() {
      // Make sure no executor is killed while some task is launching on it
      val taskDescs = withLock {
        // Filter out executors under killing
        val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
        val workOffers = activeExecutors.map {
          case (id, executorData) =>
            new WorkerOffer(id, executorData.executorHost, executorData.freeCores,
              Some(executorData.executorAddress.hostPort))
        }.toIndexedSeq
        scheduler.resourceOffers(workOffers)
      }
      if (!taskDescs.isEmpty) {
        launchTasks(taskDescs)
      }
    }

该方法主要是为了给executor分配计算资源,首先是taskDescs变量,它是withLock()方法得返回值,withLock()方法是一个加锁操作。注释里写的很清楚,它要确保在task launching时没有executor被kill或者正处在被killing。

首先通过executorDataMap.filterKeys()方法过滤到只剩下alive的executor,剩下通过resourceOffers()方法对worker进行resource的分配。

最后在没有问题的情况下执行launchTasks(taskDescs),其代码内部如下

private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
      for (task <- tasks.flatten) {
        val serializedTask = TaskDescription.encode(task)
        if (serializedTask.limit() >= maxRpcMessageSize) {
          Option(scheduler.taskIdToTaskSetManager.get(task.taskId)).foreach { taskSetMgr =>
            try {
              var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                "spark.rpc.message.maxSize (%d bytes). Consider increasing " +
                "spark.rpc.message.maxSize or using broadcast variables for large values."
              msg = msg.format(task.taskId, task.index, serializedTask.limit(), maxRpcMessageSize)
              taskSetMgr.abort(msg)
            } catch {
              case e: Exception => logError("Exception in error callback", e)
            }
          }
        }
        else {
          val executorData = executorDataMap(task.executorId)
          executorData.freeCores -= scheduler.CPUS_PER_TASK

          logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
            s"${executorData.executorHost}.")

          executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
        }
      }
    }

首先做个对ask serialized后message的超长判断,默认为128M,之后进行executorData的创建并更新一下资源信息,最后通过executorEndpoint.send方法发送message

那该executorEndpoint的发信对象是谁呢,点进LaunchTask这个msg对象,内部有一段注释

// Driver to executors
  case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage

Driver to executors

从这里我们知道该msg是发给executor的,我们找到org.apache.spark.executor.CoarseGrainedExecutorBackend,这个类的作用和CoarseGrainedSchedulerBackend类似,都是用来管理资源的,我们可以看它的receive()方法内有

case LaunchTask(data) =>
      if (executor == null) {
        exitExecutor(1, "Received LaunchTask command but executor was null")
      } else {
        val taskDesc = TaskDescription.decode(data.value)
        logInfo("Got assigned task " + taskDesc.taskId)
        executor.launchTask(this, taskDesc)
      }

首先对task进行解码,并且执行executor.launchTask(this, taskDesc),至此,task才真正开始被执行,launchTask()方法内部如下

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
    val tr = new TaskRunner(context, taskDescription)
    runningTasks.put(taskDescription.taskId, tr)
    threadPool.execute(tr)
  }

就是将task创建为一个TaskRunner并放入线程池中等待执行,具体执行的流程在override的run()方法内部,在执行结束后,会有结果和其他信息返回driver段,至此spark任务调用基本主体到此就结束了。

后话

本人刚入行不久,本文主要也以记录自己的学习过程为主,大部分是自己通过查阅资料和动手尝试下自己的见解,若有不对的地方欢迎指出。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值