checkpoint
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
RDDCheckpointData is paired one-to-one with an RDD; it holds all checkpoint-related state for that RDD, and the actual checkpoint operations are carried out by it and its subclasses.
private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T])
extends Serializable {
import CheckpointState._
// initial state
protected var cpState = Initialized
// the RDD that contains our checkpointed data
private var cpRDD: Option[CheckpointRDD[T]] = None
/**
* Return whether the checkpoint data for this RDD is already persisted.
*/
def isCheckpointed: Boolean = RDDCheckpointData.synchronized {
cpState == Checkpointed
}
/**
* Materialize this RDD and persist its content.
* This is called immediately after the first action invoked on this RDD has completed.
*/
final def checkpoint(): Unit = {
// Guard against multiple threads checkpointing the same RDD by
// atomically flipping the state of this RDDCheckpointData
RDDCheckpointData.synchronized {
if (cpState == Initialized) {
cpState = CheckpointingInProgress
} else {
return
}
}
val newRDD = doCheckpoint()
// Update our state and truncate the RDD lineage
RDDCheckpointData.synchronized {
cpRDD = Some(newRDD)
cpState = Checkpointed
rdd.markCheckpointed()
}
}
/**
* Materialize this RDD and persist its content.
*
* Subclasses should override this method to define custom checkpointing behavior.
* @return the checkpoint RDD created in the process.
*/
protected def doCheckpoint(): CheckpointRDD[T]
/**
* Return the RDD that contains our checkpointed data.
* This is only defined if the checkpoint state is `Checkpointed`.
*/
def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { cpRDD }
/**
* Return the partitions of the resulting checkpoint RDD.
* For tests only.
*/
def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
cpRDD.map(_.partitions).getOrElse { Array.empty }
}
}
The class uses an enumeration to track the RDD's current checkpoint state; its values are Initialized, CheckpointingInProgress, and Checkpointed.
private[spark] object CheckpointState extends Enumeration {
type CheckpointState = Value
val Initialized, CheckpointingInProgress, Checkpointed = Value
}
A checkpoint advances through these states, in order:
Initialized -> CheckpointingInProgress -> Checkpointed
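The transition is guarded so that only one thread ever performs the flip. A minimal, self-contained sketch of that pattern (demo code, not from Spark):
object CheckpointStateDemo extends Enumeration {
  type CheckpointState = Value
  val Initialized, CheckpointingInProgress, Checkpointed = Value
}

class StateGuard {
  import CheckpointStateDemo._
  private var cpState = Initialized
  // Atomically flip Initialized -> CheckpointingInProgress; only the first caller wins.
  def tryStart(): Boolean = this.synchronized {
    if (cpState == Initialized) { cpState = CheckpointingInProgress; true } else false
  }
  // Mark the checkpoint as materialized.
  def finish(): Unit = this.synchronized { cpState = Checkpointed }
}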
With that in mind, back to the checkpoint entry points in the RDD class.
I. The checkpoint write path
1. RDD's checkpoint method
First, the correct usage:
val data = sc.textFile("/tmp/spark/1.data").cache() // note: cache the RDD first
sc.setCheckpointDir("/tmp/spark/checkpoint")
data.checkpoint()
data.count()
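Once count() has run, the checkpoint is materialized, which can be verified through the standard RDD API (the printed values are what one would expect given the setup above):
println(data.isCheckpointed)    // true after the action completes
println(data.getCheckpointFile) // Some(...) pointing under /tmp/spark/checkpoint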
Note that data.checkpoint() comes after cache() and before the action:
def checkpoint(): Unit = RDDCheckpointData.synchronized {
// NOTE: we use a global lock here due to complexities downstream with ensuring
// children RDD partitions point to the correct parent partitions. In the future
// we should revisit this consideration.
if (context.checkpointDir.isEmpty) {
throw new SparkException("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
checkpointData = Some(new ReliableRDDCheckpointData(this))
}
}
This call constructs a ReliableRDDCheckpointData for the RDD, so checkpointData now holds a ReliableRDDCheckpointData instance.
Note: checkpoint() must be called before any action runs on the RDD.
A look inside ReliableRDDCheckpointData:
private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T])
extends RDDCheckpointData[T](rdd) with Logging
The class extends RDDCheckpointData, which initializes the state to Initialized:
protected var cpState = Initialized
From here on, all checkpoint bookkeeping for this RDD is handled by ReliableRDDCheckpointData.
At this point the RDD is only marked for checkpointing; no checkpoint work has been performed yet.
2. rdd.doCheckpoint is called after the action
At the end of every job, runJob calls rdd.doCheckpoint, which recurses through the RDD's dependencies to see which of them need checkpointing.
def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
resultHandler: (Int, U) => Unit): Unit = {
... // (large body omitted)
rdd.doCheckpoint()
}
doCheckpoint in RDD:
private[spark] def doCheckpoint(): Unit = {
RDDOperationScope.withScope(sc, "checkpoint", allowNesting = false, ignoreParent = true) {
if (!doCheckpointCalled) {
doCheckpointCalled = true
if (checkpointData.isDefined) {
if (checkpointAllMarkedAncestors) {
// recurse along the lineage and checkpoint marked ancestors first
dependencies.foreach(_.rdd.doCheckpoint())
}
checkpointData.get.checkpoint()
} else {
dependencies.foreach(_.rdd.doCheckpoint())
}
}
}
}
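The shape of this recursion is worth spelling out: each RDD runs it at most once (guarded by doCheckpointCalled), and an RDD that is itself marked only descends into its parents when all marked ancestors are to be checkpointed, since its own checkpoint truncates the lineage anyway. A toy sketch with a hypothetical Node class (not Spark code):
case class Node(id: Int, parents: Seq[Node], marked: Boolean) {
  private var called = false
  def doCheckpoint(checkpointAllMarkedAncestors: Boolean): Unit = if (!called) {
    called = true
    if (marked) {
      // Only recurse when marked ancestors should be checkpointed too.
      if (checkpointAllMarkedAncestors) {
        parents.foreach(_.doCheckpoint(checkpointAllMarkedAncestors))
      }
      println(s"checkpointing node $id")
    } else {
      parents.foreach(_.doCheckpoint(checkpointAllMarkedAncestors))
    }
  }
}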
If data.checkpoint() was called earlier, checkpointData.isDefined is now true, so doCheckpoint recurses along the dependency chain and then invokes checkpointData.get.checkpoint():
final def checkpoint(): Unit = {
// Guard against multiple threads checkpointing the same RDD by
// atomically flipping the state of this RDDCheckpointData
RDDCheckpointData.synchronized {
if (cpState == Initialized) { // any other state means another thread is already checkpointing (or has checkpointed) this RDD
cpState = CheckpointingInProgress
} else {
return
}
}
Now cpState = CheckpointingInProgress.
Next comes the actual work, RDDCheckpointData's doCheckpoint:
val newRDD = doCheckpoint()
Subclasses must override this method; here is ReliableRDDCheckpointData's version:
protected override def doCheckpoint(): CheckpointRDD[T] = {
val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)
// Optionally clean our checkpoint files if the reference is out of scope
if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
rdd.context.cleaner.foreach { cleaner =>
cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
}
}
newRDD
}
The heavy lifting is writeRDDToCheckpointDirectory, which writes the RDD's partitions out to HDFS; it returns the resulting newRDD, which is ultimately assigned to cpRDD.
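The core idea of writeRDDToCheckpointDirectory is to run one more job over the RDD in which every task serializes its own partition into a separate part-file. A heavily simplified local-disk sketch of that idea (hypothetical helper; plain Java serialization instead of Spark's configured serializer, and it assumes local mode, where driver and executors share a filesystem):
import java.io.{File, FileOutputStream, ObjectOutputStream}
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

def writePartitionsSketch[T](sc: SparkContext, rdd: RDD[T], dir: String): Unit = {
  new File(dir).mkdirs()
  sc.runJob(rdd, (ctx: TaskContext, iter: Iterator[T]) => {
    // One file per partition, mirroring the real part-NNNNN naming scheme.
    val out = new ObjectOutputStream(
      new FileOutputStream(new File(dir, "part-%05d".format(ctx.partitionId()))))
    try iter.foreach(x => out.writeObject(x.asInstanceOf[AnyRef])) finally out.close()
  })
}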
3. Checkpointed
Finally:
RDDCheckpointData.synchronized {
cpRDD = Some(newRDD)
cpState = Checkpointed
rdd.markCheckpointed()
}
and rdd.markCheckpointed() clears the dependencies:
private[spark] def markCheckpointed(): Unit = {
clearDependencies()
partitions_ = null
deps = null // Forget the constructor argument for dependencies too
}
Why clear the dependencies?
Suppose an iterative job runs 1000 steps and fails at step 999. If the lineage were still intact, Spark would follow it and recompute the entire chain of RDDs from scratch. With the lineage cleared, the checkpointed RDD has already been stored (and can additionally be kept in memory with persist), so the job resumes directly from the result of step 999, which is far more efficient.
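The truncation can be observed directly with toDebugString (standard RDD API; assuming the checkpoint directory set earlier and a hypothetical data2 RDD):
val data2 = sc.textFile("/tmp/spark/1.data").cache()
println(data2.toDebugString) // full lineage, e.g. MapPartitionsRDD -> HadoopRDD
data2.checkpoint()
data2.count()                // the action materializes the checkpoint
println(data2.toDebugString) // lineage now rooted at a ReliableCheckpointRDD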
II. The checkpoint read path
When an RDD is to be computed, Spark first checks whether its dependency chain has been replaced by a checkpoint. If the RDD was checkpointed, its dependencies are redirected to the CheckpointRDD:
final def dependencies: Seq[Dependency[_]] = {
checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
if (dependencies_ == null) {
dependencies_ = getDependencies
}
dependencies_
}
}
When the RDD is later computed, iterator is invoked, and it first decides whether the partition can come from a checkpoint:
computeOrReadCheckpoint(split, context)
Depending on whether the RDD was checkpointed, this either recomputes the partition or reads it through the CheckpointRDD in the dependencies. If the RDD was checkpointed, firstParent is the CheckpointRDD, and its iterator does the work:
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context)
}
If the RDD was checkpointed, the computation goes through ReliableCheckpointRDD's compute:
override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val file = new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index))
ReliableCheckpointRDD.readCheckpointFile(file, broadcastedConf, context)
}
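The file name comes from ReliableCheckpointRDD.checkpointFileName, which is essentially just a fixed format string, so partition i is read back from the same part-file it was written to (paraphrased from the Spark source):
private def checkpointFileName(partitionIndex: Int): String = "part-%05d".format(partitionIndex)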
readCheckpointFile reads back the checkpoint data written earlier:
def readCheckpointFile[T](
path: Path,
broadcastedConf: Broadcast[SerializableConfiguration],
context: TaskContext): Iterator[T] = {
val env = SparkEnv.get
val fs = path.getFileSystem(broadcastedConf.value.value)
val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
val fileInputStream = fs.open(path, bufferSize)
val serializer = env.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => deserializeStream.close())
deserializeStream.asIterator.asInstanceOf[Iterator[T]]
}
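For symmetry with the earlier write sketch, here is a toy local-disk counterpart that turns one part-file back into an iterator (hypothetical helper using plain Java serialization, not Spark code):
import java.io.{EOFException, File, FileInputStream, ObjectInputStream}

def readPartitionSketch[T](dir: String, partitionIndex: Int): Iterator[T] = {
  val in = new ObjectInputStream(
    new FileInputStream(new File(dir, "part-%05d".format(partitionIndex))))
  // Drain objects until EOF, closing the stream at the end.
  Iterator.continually {
    try Some(in.readObject().asInstanceOf[T])
    catch { case _: EOFException => in.close(); None }
  }.takeWhile(_.isDefined).map(_.get)
}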
And with that, the checkpointed data is read back.