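Spark: source code analysis of how partitions are assigned when creating an RDD from a collection with a specified partition count
First, here is the code that creates an RDD from a collection with a specified number of partitions: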
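import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object test03_RDDmem {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("test03_RDDmem")
    val sc: SparkContext = new SparkContext(conf)
    // Create an RDD from a collection, requesting 3 partitions
    val listRDD: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4), 3)
    // Save the RDD to the local file system (one part file per partition)
    listRDD.saveAsTextFile("E:\\IDEAworkspace\\bigdata-MrG\\spark-2021\\output")
    // Release the SparkContext's resources
    sc.stop()
  }
}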
The files actually written locally contain:
Partition 1: 1
Partition 2: 2
Partition 3: 3, 4
1. First, let's look at the source of makeRDD:
/** Distribute a local Scala collection to form an RDD.
 *
 * This method is identical to `parallelize`.
 */
def makeRDD[T: ClassTag](
    seq: Seq[T],
    numSlices: Int = defaultParallelism): RDD[T] = withScope {
  parallelize(seq, numSlices)
}
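It takes two parameters:
seq: the collection we pass in.
numSlices: the number of partitions we request. If it is omitted, it defaults to defaultParallelism; here we pass 3.
Under the hood, makeRDD simply calls parallelize, so let's follow the parallelize source next.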
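The delegation and the default argument can be modelled in plain Scala. `ToyContext` and its methods are hypothetical stand-ins, not Spark classes. They keep only the shape of the signatures: `numSlices` defaults to a member of the enclosing object, and `makeRDD` forwards to `parallelize`. In this toy, `parallelize` just returns the slice count it would use.

```scala
// Hypothetical stand-in for SparkContext: only models how numSlices is resolved.
class ToyContext(val defaultParallelism: Int) {
  // Like makeRDD: same default, simply forwards to parallelize.
  def makeRDD[T](seq: Seq[T], numSlices: Int = defaultParallelism): Int =
    parallelize(seq, numSlices)

  // Like parallelize, but returns the slice count instead of building an RDD.
  def parallelize[T](seq: Seq[T], numSlices: Int = defaultParallelism): Int =
    numSlices
}

object DefaultSlicesDemo {
  def main(args: Array[String]): Unit = {
    val ctx = new ToyContext(defaultParallelism = 8)
    println(ctx.makeRDD(List(1, 2, 3, 4), 3)) // explicit value wins
    println(ctx.makeRDD(List(1, 2, 3, 4)))    // falls back to the default
  }
}
```

For reference, in `local[*]` mode Spark's real `defaultParallelism` is `spark.default.parallelism` if that is set. Otherwise it is the number of available cores.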
2. The parallelize source:
/** Distribute a local Scala collection to form an RDD.
 *
 * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call
 * to parallelize and before the first action on the RDD, the resultant RDD will reflect the
 * modified collection. Pass a copy of the argument to avoid this.
 * @note avoid using `parallelize(Seq())` to create an empty `RDD`. Consider `emptyRDD` for an
 * RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions.
 */
def parallelize[T: ClassTag](
    seq: Seq[T],
    numSlices: Int = defaultParallelism): RDD[T] = withScope {
  assertNotStopped()
  new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
Like makeRDD, parallelize also gives numSlices a default value. The part to focus on is its return value:
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
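The first `@note` in the doc comment says `parallelize` is lazy. The returned RDD keeps a reference to `seq`, and the slicing happens later. Here is a minimal sketch of why that matters. It assumes a hypothetical `ToyParallelCollection` in which a `lazy val` plays the role of partitions being computed on first use. The real class computes them in `getPartitions`. The slicing uses the same index arithmetic as Spark's `positions` helper.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical toy: holds a reference to the data and slices it only on first access.
class ToyParallelCollection[T](data: scala.collection.Seq[T], numSlices: Int) {
  // Evaluated on first access, roughly like partitions being computed at the first action.
  lazy val partitions: Seq[List[T]] = {
    val items = data.toList
    val n = items.length.toLong
    (0 until numSlices).map { i =>
      val start = ((i * n) / numSlices).toInt
      val end = (((i + 1) * n) / numSlices).toInt
      items.slice(start, end)
    }
  }
}

object LazyParallelizeDemo {
  def main(args: Array[String]): Unit = {
    val buf = ArrayBuffer(1, 2, 3, 4)
    val toy = new ToyParallelCollection(buf, 3)
    buf += 5                // mutate after "parallelize", before first use
    println(toy.partitions) // the 5 shows up in the last slice
  }
}
```

If the constructor received a copy, for example `buf.toList`, the data would stay frozen at four elements. That is the advice given in the `@note`.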
So next we follow the source of ParallelCollectionRDD.
3. The ParallelCollectionRDD source:
private[spark] class ParallelCollectionRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val data: Seq[T],
    numSlices: Int,
    locationPrefs: Map[Int, Seq[String]])
    extends RDD[T](sc, Nil) {
  // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
  // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
  // instead.
  // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
  override def getPartitions: Array[Partition] = {
    val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  }
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
  }
  override def getPreferredLocations(s: Partition): Seq[String] = {
    locationPrefs.getOrElse(s.index, Nil)
  }
}
ParallelCollectionRDD also extends RDD and overrides getPartitions, compute, and getPreferredLocations. The one we care about here is getPartitions:
override def getPartitions: Array[Partition] = {
  val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
  slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
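`getPartitions` pairs each slice with its position and the RDD's `id`. The sketch below mimics that step. It uses a hypothetical case class `ToyPartition` in place of Spark's `ParallelCollectionPartition`. The slices are hard-coded to the result this article derives for `List(1, 2, 3, 4)` with 3 slices, and `rddId` is an arbitrary placeholder for the RDD's `id`.

```scala
// Hypothetical stand-in for ParallelCollectionPartition: RDD id, partition index, data.
final case class ToyPartition[T](rddId: Int, index: Int, values: Seq[T])

object GetPartitionsDemo {
  // Mirrors: slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  def toPartitions[T](rddId: Int, slices: Array[Seq[T]]): Array[ToyPartition[T]] =
    slices.indices.map(i => ToyPartition(rddId, i, slices(i))).toArray

  def main(args: Array[String]): Unit = {
    val slices: Array[Seq[Int]] = Array(Seq(1), Seq(2), Seq(3, 4))
    toPartitions(rddId = 0, slices).foreach(println)
  }
}
```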
Next, let's follow the slice method it calls.
4. The slice source (from the ParallelCollectionRDD companion object):
/**
 * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
 * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
 * it efficient to run Spark over RDDs representing large sets of numbers. And if the collection
 * is an inclusive Range, we use inclusive range for the last slice.
 */
def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
  if (numSlices < 1) {
    throw new IllegalArgumentException("Positive number of slices required")
  }
  // Sequences need to be sliced at the same set of index positions for operations
  // like RDD.zip() to behave as expected
  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
    (0 until numSlices).iterator.map { i =>
      val start = ((i * length) / numSlices).toInt
      val end = (((i + 1) * length) / numSlices).toInt
      (start, end)
    }
  }
  seq match {
    case r: Range =>
      positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
        // If the range is inclusive, use inclusive range for the last slice
        if (r.isInclusive && index == numSlices - 1) {
          new Range.Inclusive(r.start + start * r.step, r.end, r.step)
        }
        else {
          new Range(r.start + start * r.step, r.start + end * r.step, r.step)
        }
      }.toSeq.asInstanceOf[Seq[Seq[T]]]
    case nr: NumericRange[_] =>
      // For ranges of Long, Double, BigInteger, etc
      val slices = new ArrayBuffer[Seq[T]](numSlices)
      var r = nr
      for ((start, end) <- positions(nr.length, numSlices)) {
        val sliceSize = end - start
        slices += r.take(sliceSize).asInstanceOf[Seq[T]]
        r = r.drop(sliceSize)
      }
      slices
    case _ =>
      val array = seq.toArray // To prevent O(n^2) operations for List etc
      positions(array.length, numSlices).map { case (start, end) =>
        array.slice(start, end).toSeq
      }.toSeq
  }
}
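The List in this article goes through the default branch, but the `Range` branch is worth a quick look. Instead of copying elements, it describes each slice as a smaller `Range`. The sketch below re-implements just that branch in plain Scala. It uses the `Range(start, end, step)` and `Range.inclusive(...)` factories rather than the constructors in the Spark source, so it compiles on both Scala 2.12 and 2.13. The names `RangeSliceDemo` and `sliceRange` are hypothetical.

```scala
object RangeSliceDemo {
  // Same index arithmetic as the positions helper in ParallelCollectionRDD.slice.
  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
    (0 until numSlices).iterator.map { i =>
      val start = ((i * length) / numSlices).toInt
      val end = (((i + 1) * length) / numSlices).toInt
      (start, end)
    }

  // Re-implementation of the Range branch: each slice is itself a Range, no copying.
  def sliceRange(r: Range, numSlices: Int): List[Range] =
    positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
      if (r.isInclusive && index == numSlices - 1)
        Range.inclusive(r.start + start * r.step, r.end, r.step) // keep the inclusive end
      else
        Range(r.start + start * r.step, r.start + end * r.step, r.step)
    }.toList

  def main(args: Array[String]): Unit =
    sliceRange(1 to 10, 3).foreach(s => println(s.toList))
}
```

For `1 to 10` with 3 slices this gives 1..3, 4..6 and an inclusive 7..10, so the last slice still ends at the original `end`.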
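The slice method first checks numSlices and throws an IllegalArgumentException if it is less than 1. It then defines a local helper function, positions, and finally pattern-matches on seq. We passed a List, which is neither a Range nor a NumericRange, so we can go straight to the default branch: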
case _ =>
  val array = seq.toArray // To prevent O(n^2) operations for List etc
  positions(array.length, numSlices).map { case (start, end) =>
    array.slice(start, end).toSeq
  }.toSeq
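This branch calls the local function positions, passing in the collection's length and numSlices. Here is positions again. It returns an Iterator of (start, end) pairs: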
def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
  (0 until numSlices).iterator.map { i =>
    val start = ((i * length) / numSlices).toInt
    val end = (((i + 1) * length) / numSlices).toInt
    (start, end)
  }
}
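In formula form, let n be the collection length and k = numSlices, with floor division. Slice i then covers the half-open index range [start_i, end_i). Because start_0 = 0 and end_{k-1} = n, every index is covered exactly once. Also, slice sizes never differ by more than one:

```latex
\[
\text{start}_i = \left\lfloor \frac{i\,n}{k} \right\rfloor,\qquad
\text{end}_i = \left\lfloor \frac{(i+1)\,n}{k} \right\rfloor,\qquad i = 0,\dots,k-1
\]
\[
\text{size}_i = \text{end}_i - \text{start}_i \in \left\{ \left\lfloor \tfrac{n}{k} \right\rfloor,\ \left\lceil \tfrac{n}{k} \right\rceil \right\}
\]
\[
n = 4,\ k = 3:\quad (0,1),\ (1,2),\ (2,4) \;\Rightarrow\; \text{sizes } 1,\ 1,\ 2
\]
```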
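Take my code as an example: List(1, 2, 3, 4) has length 4, and numSlices is 3.
So positions runs as follows (all divisions are integer divisions):
First iteration, i = 0: start = 0*4/3 = 0, end = (0+1)*4/3 = 1, giving (0,1)
Second iteration, i = 1: start = 1*4/3 = 1, end = (1+1)*4/3 = 2, giving (1,2)
Third iteration, i = 2: start = 2*4/3 = 2, end = (2+1)*4/3 = 4, giving (2,4)
So this call yields (0,1), (1,2), (2,4).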
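The same trace can be reproduced by copying `positions` out of `slice` into a standalone object. The object name `PositionsDemo` is my own; the function body matches the Spark source quoted above.

```scala
object PositionsDemo {
  // Copied from the local helper in ParallelCollectionRDD.slice.
  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
    (0 until numSlices).iterator.map { i =>
      val start = ((i * length) / numSlices).toInt
      val end = (((i + 1) * length) / numSlices).toInt
      (start, end)
    }
  }

  def main(args: Array[String]): Unit = {
    println(positions(4, 3).toList)  // List((0,1), (1,2), (2,4))
    println(positions(10, 4).toList) // slice sizes 2, 3, 2, 3
  }
}
```

Because `length` is a `Long`, `i * length` is computed in 64-bit arithmetic before dividing. This avoids `Int` overflow for very large collections.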
These (start, end) pairs are then handed back to the default branch:
case _ =>
  val array = seq.toArray // To prevent O(n^2) operations for List etc
  positions(array.length, numSlices).map { case (start, end) =>
    array.slice(start, end).toSeq
  }.toSeq
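For each pair, map runs array.slice(start, end).toSeq. Next, let's look at the source of the array's slice method (not to be confused with ParallelCollectionRDD.slice).
5. The array slice source: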
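Combining `positions` with the default branch gives a small, Spark-free re-implementation of `slice` for ordinary collections. The object and method names are hypothetical, and the `Range` and `NumericRange` branches are left out. For `List(1, 2, 3, 4)` with 3 slices it produces the same grouping as the output files at the start of the article.

```scala
import scala.reflect.ClassTag

object GenericSliceDemo {
  private def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
    (0 until numSlices).iterator.map { i =>
      (((i * length) / numSlices).toInt, (((i + 1) * length) / numSlices).toInt)
    }

  // Default branch of ParallelCollectionRDD.slice, without the Range special cases.
  def sliceGeneric[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    // require throws IllegalArgumentException, the same exception type as the Spark check
    require(numSlices >= 1, "Positive number of slices required")
    val array = seq.toArray // index into an array instead of walking the List
    positions(array.length, numSlices).map { case (start, end) =>
      array.slice(start, end).toSeq
    }.toList
  }

  def main(args: Array[String]): Unit =
    sliceGeneric(List(1, 2, 3, 4), 3).zipWithIndex.foreach { case (part, i) =>
      println(s"partition ${i + 1}: ${part.mkString(",")}")
    }
}
```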
def slice(from: Int, until: Int): Repr = {
  val lo = math.max(from, 0)
  val hi = math.min(math.max(until, 0), length)
  val elems = math.max(hi - lo, 0)
  val b = newBuilder
  b.sizeHint(elems)
  var i = lo
  while (i < hi) {
    b += self(i)
    i += 1
  }
  b.result()
}
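Here array contains (1, 2, 3, 4), so the first call is (1, 2, 3, 4).slice(0, 1):
lo = 0 and hi = 1, so elems = 1.
In other words, the method computes the start index lo, the exclusive end index hi, and the element count elems, clamping them to the array bounds.
It then creates a Builder b and calls sizeHint to set its initial capacity. Here is the sizeHint source (from ArrayBuilder), together with the fields and helpers it uses: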
private var elems: Array[T] = _
private var capacity: Int = 0
private var size: Int = 0
private def mkArray(size: Int): Array[T] = {
  val newelems = new Array[T](size)
  if (this.size > 0) Array.copy(elems, 0, newelems, 0, this.size)
  newelems
}
private def resize(size: Int) {
  elems = mkArray(size)
  capacity = size
}
override def sizeHint(size: Int) {
  if (capacity < size) resize(size)
}
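If capacity (the current capacity) is smaller than size (the requested size), sizeHint calls resize.
Since b has just been created, its capacity is 0, and the requested size here is 1, so resize runs.
resize in turn calls mkArray,
which allocates a new, empty array of length size.
If b's size is greater than 0, mkArray also copies the existing elements from the old array into the new one, and resize assigns the result to elems.
On this first call, however, this.size = 0, so mkArray simply allocates an empty array of length 1.
resize then sets capacity to the new value. At this point capacity = 1 and b's size = 0.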
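The capacity bookkeeping above can be traced with a small re-creation of the three quoted members (`mkArray`, `resize`, `sizeHint`). `ToyIntBuilder` is a hypothetical, `Int`-only class, not the Scala library's `ArrayBuilder`. Its `capacity` and `size` are public so they can be inspected, and `elems` starts as an empty array instead of being left uninitialised.

```scala
// Hypothetical Int-only copy of the ArrayBuilder members quoted above.
class ToyIntBuilder {
  private var elems: Array[Int] = new Array[Int](0)
  var capacity: Int = 0
  var size: Int = 0

  // Allocate a new array of the given length, copying existing elements if any.
  private def mkArray(size: Int): Array[Int] = {
    val newelems = new Array[Int](size)
    if (this.size > 0) Array.copy(elems, 0, newelems, 0, this.size)
    newelems
  }

  private def resize(size: Int): Unit = {
    elems = mkArray(size)
    capacity = size
  }

  // Only grows: a hint that is not larger than the current capacity is ignored.
  def sizeHint(size: Int): Unit = {
    if (capacity < size) resize(size)
  }
}

object SizeHintDemo {
  def main(args: Array[String]): Unit = {
    val b = new ToyIntBuilder
    b.sizeHint(1)
    println(s"capacity=${b.capacity}, size=${b.size}") // capacity=1, size=0
  }
}
```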
Now back to the slice method:
def slice(from: Int, until: Int): Repr = {
  val lo = math.max(from, 0)
  val hi = math.min(math.max(until, 0), length)
  val elems = math.max(hi - lo, 0)
  val b = newBuilder
  b.sizeHint(elems)
  var i = lo
  while (i < hi) {
    b += self(i)
    i += 1
  }
  b.result()
}
Continuing, the loop starts with i = 0 and appends self(i), which is 1, to b. Here is the source of +=:
def +=(elem: T): this.type = {
  ensureSize(size + 1)
  elems(size) = elem
  size += 1
  this
}
+= stores the incoming elem into elems at index size, but first it calls ensureSize:
private def ensureSize(size: Int) {
  if (capacity < size || capacity == 0) {
    var newsize = if (capacity == 0) 16 else capacity * 2
    while (newsize < size) newsize *= 2
    resize(newsize)
  }
}
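Here ensureSize is called with size + 1 = 1 while capacity = 1, so the condition (capacity < size || capacity == 0) is false and the block is skipped.
Then elems(0) = 1 and size is incremented. Now capacity = 1 and size = 1.
Next, i += 1. i is no longer less than hi, so the while loop exits.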
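`ensureSize` only does work when the hint was too small or missing. A hypothetical `Int`-only toy with the quoted `sizeHint`, `+=` and `ensureSize` logic shows the difference. After `sizeHint(1)` the first append fits. Without a hint, the first append grows the buffer straight to 16. In this toy, `mkArray` is folded into `resize`, and the parameters are renamed to avoid shadowing.

```scala
// Hypothetical Int-only toy reproducing sizeHint, += and ensureSize as quoted.
class GrowingIntBuilder {
  private var elems: Array[Int] = new Array[Int](0)
  var capacity: Int = 0
  var size: Int = 0

  // Allocate newSize slots and copy the current elements (mkArray folded in).
  private def resize(newSize: Int): Unit = {
    val newelems = new Array[Int](newSize)
    if (size > 0) Array.copy(elems, 0, newelems, 0, size)
    elems = newelems
    capacity = newSize
  }

  def sizeHint(hint: Int): Unit = if (capacity < hint) resize(hint)

  // Jump to 16 when empty, otherwise double until the request fits.
  private def ensureSize(needed: Int): Unit = {
    if (capacity < needed || capacity == 0) {
      var newsize = if (capacity == 0) 16 else capacity * 2
      while (newsize < needed) newsize *= 2
      resize(newsize)
    }
  }

  def +=(elem: Int): this.type = {
    ensureSize(size + 1)
    elems(size) = elem
    size += 1
    this
  }
}

object EnsureSizeDemo {
  def main(args: Array[String]): Unit = {
    val hinted = new GrowingIntBuilder
    hinted.sizeHint(1)
    hinted += 1
    println(s"hinted:   capacity=${hinted.capacity}, size=${hinted.size}")
    val unhinted = new GrowingIntBuilder
    unhinted += 1
    println(s"unhinted: capacity=${unhinted.capacity}, size=${unhinted.size}")
  }
}
```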
Finally, result is called. Its source:
def result() = {
  if (capacity != 0 && capacity == size) elems
  else mkArray(size)
}
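When b's capacity is nonzero and equal to size, result returns elems directly, without copying.
So this call returns an array containing just the element 1.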
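Whether `result()` copies depends only on whether the buffer is exactly full. The hypothetical toy below keeps just that check and receives its state directly instead of building it up with `+=`. It makes both outcomes visible. When `capacity == size`, the very same array object comes back. Otherwise, a trimmed copy of length `size` is made.

```scala
// Hypothetical toy: just the result() decision from the quoted ArrayBuilder source.
final class ResultToy(elems: Array[Int], capacity: Int, size: Int) {
  private def mkArray(n: Int): Array[Int] = {
    val newelems = new Array[Int](n)
    if (size > 0) Array.copy(elems, 0, newelems, 0, size)
    newelems
  }

  def result(): Array[Int] =
    if (capacity != 0 && capacity == size) elems // exact fit: no copy
    else mkArray(size)                            // otherwise copy the first size elements
}

object ResultDemo {
  def main(args: Array[String]): Unit = {
    val exact = Array(1) // state after slice(0, 1): capacity 1, size 1
    println(new ResultToy(exact, 1, 1).result() eq exact)
    val padded = new Array[Int](16) // state without sizeHint: capacity 16, size 1
    padded(0) = 1
    println(new ResultToy(padded, 16, 1).result().toList)
  }
}
```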
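The second call, slice(1, 2), works essentially the same as slice(0, 1), because both produce exactly one element.
The third call, slice(2, 4), follows the same process with elems = 2. sizeHint(2) sets capacity to 2, both += calls fit without growing, and result returns the array (3, 4). The parameter names from and until already hint at this, and the source confirms it: both are array indices, forming the half-open interval [from, until). Together, the three slices give exactly the partition contents shown at the start: 1, 2, and 3, 4.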
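The half-open, clamped behaviour can be checked directly against the standard library's `Array.slice`, in plain Scala without Spark. The first three calls are the ones made for this article's example. The last three show the `math.max`/`math.min` clamping from the quoted source.

```scala
object HalfOpenSliceDemo {
  def main(args: Array[String]): Unit = {
    val array = Array(1, 2, 3, 4)
    // The three (start, end) pairs produced by positions(4, 3)
    for ((from, until) <- List((0, 1), (1, 2), (2, 4)))
      println(s"slice($from, $until) = ${array.slice(from, until).toList}")
    // Out-of-range bounds are clamped instead of throwing
    println(array.slice(-2, 2).toList)  // lo clamped to 0
    println(array.slice(3, 100).toList) // hi clamped to the length
    println(array.slice(3, 1).toList)   // hi < lo: empty
  }
}
```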