makeRDD源码解析
// 返回ParallelCollectionRDD
def makeRDD[T: ClassTag](
seq: Seq[T],
numSlices: Int = defaultParallelism): RDD[T] = withScope {
parallelize(seq, numSlices)
}
//这里分区数numSlices参数进行了初始化,如果没传入该参数就会是初始化的默认值
//将代码块{parallelize(seq, numSlices)}作为参数传给withScope调用
val rdd1: RDD[Int] = sparkContext.makeRDD(list)
val rdd2: RDD[Int] = sparkContext.parallelize(list)
//这两个创建RDD的方式等价,因为makeRDD调用的是parallelize方法
makeRDD方法实际上是将传入的集合和分区数两个参数传给parallelize方法然后将返回结果作为参数传给withScope方法调用,下面来分析一下parallelize的源码:
// 1.parallelize源码
/*
Distribute a local Scala collection to form an RDD
译文:将本地集合分发到RDD
*/
def parallelize[T: ClassTag](
seq: Seq[T],
numSlices: Int = defaultParallelism): RDD[T] = withScope {
assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
//Seq是传入的集合,numSlices是分区数,这里分区数是默认值defaultParallelism,以下来分析defaultParallelism源码:
点开找到重写方法:
override def defaultParallelism(): Int = backend.defaultParallelism()
再点开:def defaultParallelism(): Int ctrl+H看一下实现类,点击LocalSchedulerBackend
在该类里找到源码如下:
override def defaultParallelism(): Int =
scheduler.conf.getInt("spark.default.parallelism", totalCores)
"spark.default.parallelism"是默认并行度,会从配置conf中获取,获取不到,则为totalCores
所以可以得出结论:
若在调用makeRDD方法时,传入了参数作为分区数,那么该参数值是最终分区数,若没有传参数,则分区数为默认值
totalCores,那么totalCores是多少呢?举个例子:在本地模式local环境下,
setMaster(local): 此时分区数默认为1
setMaster(local[n]) 此时分区数默认为n
setMaster(local[*]) 此时分区数默认为当前环境下最大核数
(假如电脑为8核16线程,那么local[*]默认分区数为16)
在调用saveAsTextFile方法生成文件时,每个分区对应一个文件
// 2.parallelize调用ParallelCollectionRDD,ParallelCollectionRDD伴生类源码如下
private[spark] class ParallelCollectionRDD[T: ClassTag](
sc: SparkContext,
@transient private val data: Seq[T],
numSlices: Int,
locationPrefs: Map[Int, Seq[String]])
extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
// instead.
// UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
override def getPartitions: Array[Partition] = {
val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
override def compute(s: Partition, context: TaskContext): Iterator[T] = {
new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
}
override def getPreferredLocations(s: Partition): Seq[String] = {
locationPrefs.getOrElse(s.index, Nil)
}
}
//3.parallelize实际调用的伴生对象ParallelCollectionRDD源码:
private object ParallelCollectionRDD {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
* it efficient to run Spark over RDDs representing large sets of numbers. And if the collection
* is an inclusive Range, we use inclusive range for the last slice.
*/
def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
//检查分区数是否合法,设置分区数小于1则报异常
if (numSlices < 1) {
throw new IllegalArgumentException("Positive number of partitions required")
}
// Sequences need to be sliced at the same set of index positions for operations
// like RDD.zip() to behave as expected
def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
(0 until numSlices).iterator.map { i =>
val start = ((i * length) / numSlices).toInt
val end = (((i + 1) * length) / numSlices).toInt
(start, end)
}
}
//这里模式匹配判断传入的Seq集合是不是range类
seq match {
case r: Range =>
positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
// If the range is inclusive, use inclusive range for the last slice
if (r.isInclusive && index == numSlices - 1) {
new Range.Inclusive(r.start + start * r.step, r.end, r.step)
}
else {
new Range(r.start + start * r.step, r.start + end * r.step, r.step)
}
}.toSeq.asInstanceOf[Seq[Seq[T]]]
case nr: NumericRange[_] =>
// For ranges of Long, Double, BigInteger, etc
val slices = new ArrayBuffer[Seq[T]](numSlices)
var r = nr
for ((start, end) <- positions(nr.length, numSlices)) {
val sliceSize = end - start
slices += r.take(sliceSize).asInstanceOf[Seq[T]]
r = r.drop(sliceSize)
}
slices
//不是range类则执行这一条
case _ =>
//先将集合转换成Array集合
val array = seq.toArray // To prevent O(n^2) operations for List etc
//这里将集合的长度和分区数作为参数调用上面的positions方法,返回一个元素为(start, end)的迭代器,迭代器中元素个数为分区数。然后迭代器中每个元组(start,end)调用slice方法来切分数组
positions(array.length, numSlices).map { case (start, end) =>
//这里的slice方法是切分数组=>(from,until)
array.slice(start, end).toSeq
}.toSeq
}
}
}
// 4.slice方法源码如下:
override def slice(from: Int, until: Int): Array[T] = {
val lo = math.max(from, 0)
val hi = math.min(math.max(until, 0), repr.length)
val size = math.max(hi - lo, 0)
val result = java.lang.reflect.Array.newInstance(elementClass, size)
if (size > 0) {
Array.copy(repr, lo, result, 0, size)
}
result.asInstanceOf[Array[T]]
}
/*
slice切分数组规则如下:
例如:设置分区数为3,List(1,2,3,4,5)经过position方法后,变成Array((0,1),(1,3),(3,5))
那么Array调用map后,Array中每个元组调用slice方法,首先是(0,1),List(1,2,3,4,5)从0号索引开始不包含1,也就是1 放在第一个分区,然后(1,3),List(1,2,3,4,5)从1号索引位置开始不包含3,那么就是2,3放在第二个分区,然后是(3,5),List(1,2,3,4,5)从3号索引位置开始不包含5号索引位置的数字,即4,5放在第三个分区内
*/
最后看一下withScope方法源码:
// 5.withScope源码:不用过分关注,parallelize返回RDD对象然后作为参数调用withScope返回RDD
/**
* Execute the given body such that all RDDs created in this body will have the same scope.
译文:执行给定的主体,使在该主体中创建的所有RDDs具有相同的作用域
* The name of the scope will be the first method name in the stack trace that is not the same as this method's.
译文:作用域的名称将是堆栈跟踪中与此方法不同的第一个方法名称。
* Note: Return statements are NOT allowed in body.
*/
private[spark] def withScope[T](
sc: SparkContext,
allowNesting: Boolean = false)(body: => T): T = {
val ourMethodName = "withScope"
val callerMethodName = Thread.currentThread.getStackTrace()
.dropWhile(_.getMethodName != ourMethodName)
.find(_.getMethodName != ourMethodName)
.map(_.getMethodName)
.getOrElse {
// Log a warning just in case, but this should almost certainly never happen
logWarning("No valid method name for this RDD operation scope!")
"N/A"
}
withScope[T](sc, callerMethodName, allowNesting, ignoreParent = false)(body)
}
总结: