Spark: Source Code Analysis of Creating an RDD from a Collection with a Specified Number of Partitions


First, let's look at the code that creates an RDD from a collection with a specified number of partitions:

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object test03_RDDmem {

  def main(args: Array[String]): Unit = {

    val conf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("test03_RDDmem")

    val sc: SparkContext = new SparkContext(conf)

    // Create an RDD from a collection, specifying 3 partitions
    val listRDD: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4), 3)

    // Save the RDD to the local file system (one output file per partition)
    listRDD.saveAsTextFile("E:\\IDEAworkspace\\bigdata-MrG\\spark-2021\\output")

    sc.stop()
  }

}

The files actually generated locally contain the following:

Partition 0 (part-00000): 1
Partition 1 (part-00001): 2
Partition 2 (part-00002): 3 and 4
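
Instead of opening the output files, a quick way to verify the per-partition contents is glom(), which gathers each partition into an array. A minimal sketch, assuming the same sc and listRDD as above:

  // Print each partition's contents on its own line
  listRDD.glom().collect().foreach(a => println(a.mkString(",")))
  // 1
  // 2
  // 3,4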

1. First, let's look at the source code of makeRDD:

  /** Distribute a local Scala collection to form an RDD.
   *
   * This method is identical to `parallelize`.
   */
  def makeRDD[T: ClassTag](
      seq: Seq[T],
      numSlices: Int = defaultParallelism): RDD[T] = withScope {
    parallelize(seq, numSlices)
  }

As you can see, it takes two parameters:
seq: the collection we pass in.
numSlices: the number of partitions we specify; if omitted, it falls back to the default value defaultParallelism. In our case we explicitly passed 3 (see the quick check below).
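
A quick check of the default, as a hedged sketch (assuming a local[*] master, where defaultParallelism is the number of available cores unless spark.default.parallelism is set):

  val defaultRDD = sc.makeRDD(List(1, 2, 3, 4))   // numSlices omitted
  println(sc.defaultParallelism)                  // e.g. 8 on an 8-core machine
  println(defaultRDD.getNumPartitions)            // same value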
Moreover, makeRDD simply delegates to parallelize under the hood, so let's keep following into the parallelize source code.
2. The parallelize source code is as follows:

 /** Distribute a local Scala collection to form an RDD.
   *
   * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call
   * to parallelize and before the first action on the RDD, the resultant RDD will reflect the
   * modified collection. Pass a copy of the argument to avoid this.
   * @note avoid using `parallelize(Seq())` to create an empty `RDD`. Consider `emptyRDD` for an
   * RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions.
   */
  def parallelize[T: ClassTag](
      seq: Seq[T],
      numSlices: Int = defaultParallelism): RDD[T] = withScope {
    assertNotStopped()
    new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
  }

As you can see, parallelize sets the same default for numSlices as makeRDD. The part to focus on is its return value:

new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())

So let's keep digging into ParallelCollectionRDD.
3. The ParallelCollectionRDD source code is as follows:

private[spark] class ParallelCollectionRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val data: Seq[T],
    numSlices: Int,
    locationPrefs: Map[Int, Seq[String]])
    extends RDD[T](sc, Nil) {
  // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
  // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
  // instead.
  // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.

  override def getPartitions: Array[Partition] = {
    val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  }

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    locationPrefs.getOrElse(s.index, Nil)
  }
}

As you can see, ParallelCollectionRDD extends RDD and overrides getPartitions, compute, and getPreferredLocations. The one we care about is getPartitions:

  override def getPartitions: Array[Partition] = {
    val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  }

Let's keep following the slice method it calls.
4. The slice source code is as follows:

  /**
   * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
   * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
   * it efficient to run Spark over RDDs representing large sets of numbers. And if the collection
   * is an inclusive Range, we use inclusive range for the last slice.
   */
  def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    if (numSlices < 1) {
      throw new IllegalArgumentException("Positive number of slices required")
    }
    // Sequences need to be sliced at the same set of index positions for operations
    // like RDD.zip() to behave as expected
    def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
      (0 until numSlices).iterator.map { i =>
        val start = ((i * length) / numSlices).toInt
        val end = (((i + 1) * length) / numSlices).toInt
        (start, end)
      }
    }
    seq match {
      case r: Range =>
        positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
          // If the range is inclusive, use inclusive range for the last slice
          if (r.isInclusive && index == numSlices - 1) {
            new Range.Inclusive(r.start + start * r.step, r.end, r.step)
          }
          else {
            new Range(r.start + start * r.step, r.start + end * r.step, r.step)
          }
        }.toSeq.asInstanceOf[Seq[Seq[T]]]
      case nr: NumericRange[_] =>
        // For ranges of Long, Double, BigInteger, etc
        val slices = new ArrayBuffer[Seq[T]](numSlices)
        var r = nr
        for ((start, end) <- positions(nr.length, numSlices)) {
          val sliceSize = end - start
          slices += r.take(sliceSize).asInstanceOf[Seq[T]]
          r = r.drop(sliceSize)
        }
        slices
      case _ =>
        val array = seq.toArray // To prevent O(n^2) operations for List etc
        positions(array.length, numSlices).map { case (start, end) =>
            array.slice(start, end).toSeq
        }.toSeq
    }
  }
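
Before walking through the default case, here is a quick illustration of the Range special case mentioned in the doc comment. A hedged sketch, assuming the same sc from the example above: for the inclusive range 1 to 10 with 3 slices, positions yields (0,3), (3,6), (6,10), so the partitions hold 1..3, 4..6, and 7..10, with the last slice built as an inclusive Range:

  val rangeRDD = sc.makeRDD(1 to 10, 3)
  rangeRDD.glom().collect().foreach(a => println(a.mkString(",")))
  // 1,2,3
  // 4,5,6
  // 7,8,9,10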

Back in slice itself: it first validates numSlices, throwing an IllegalArgumentException if it is less than 1. It then defines an inner function positions and pattern-matches on seq. Since we passed a List, we fall straight into the default case:

  case _ =>
        val array = seq.toArray // To prevent O(n^2) operations for List etc
        positions(array.length, numSlices).map { case (start, end) =>
            array.slice(start, end).toSeq
        }.toSeq

This again calls the inner function positions, passing the collection's length and numSlices. positions returns an iterator of (start, end) index pairs:

    def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
      (0 until numSlices).iterator.map { i =>
        val start = ((i * length) / numSlices).toInt
        val end = (((i + 1) * length) / numSlices).toInt
        (start, end)
      }
    }

Taking my code as an example, the list List(1, 2, 3, 4) has length 4 and numSlices is 3, so inside positions (note the integer division):

i = 0: start = 0 * 4 / 3 = 0, end = (0 + 1) * 4 / 3 = 1 → (0, 1)

i = 1: start = 1 * 4 / 3 = 1, end = (1 + 1) * 4 / 3 = 2 → (1, 2)

i = 2: start = 2 * 4 / 3 = 2, end = (2 + 1) * 4 / 3 = 4 → (2, 4)

So this call produces (0, 1), (1, 2), (2, 4).
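
The same arithmetic can be reproduced outside Spark. A minimal sketch of the positions logic, copied from the source above:

  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
    (0 until numSlices).iterator.map { i =>
      val start = ((i * length) / numSlices).toInt
      val end = (((i + 1) * length) / numSlices).toInt
      (start, end)
    }
  }

  positions(4, 3).foreach(println)
  // (0,1)
  // (1,2)
  // (2,4)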

These (start, end) pairs then feed back into the map in the default case:

  case _ =>
        val array = seq.toArray // To prevent O(n^2) operations for List etc
        positions(array.length, numSlices).map { case (start, end) =>
            array.slice(start, end).toSeq
        }.toSeq

The map body runs array.slice(start, end).toSeq for each pair; now let's look at the source of slice (this is the Scala collections implementation).
5. The slice source code is as follows:

  def slice(from: Int, until: Int): Repr = {
    val lo    = math.max(from, 0)
    val hi    = math.min(math.max(until, 0), length)
    val elems = math.max(hi - lo, 0)
    val b     = newBuilder
    b.sizeHint(elems)

    var i = lo
    while (i < hi) {
      b += self(i)
      i += 1
    }
    b.result()
  }

Here our array contains (1, 2, 3, 4), so the first call is (1, 2, 3, 4).slice(0, 1):
lo = 0, hi = 1, and elems = 1.
In other words, slice computes the start index lo, the end index hi, and the element count elems (see the quick check below).
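
A quick check of this index arithmetic (left-closed, right-open), runnable in any Scala REPL:

  val array = Array(1, 2, 3, 4)
  println(array.slice(0, 1).mkString(","))  // 1
  println(array.slice(1, 2).mkString(","))  // 2
  println(array.slice(2, 4).mkString(","))  // 3,4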

Next, slice creates a Builder and initializes its capacity with sizeHint. Here is the sizeHint source (this is ArrayBuilder's):

 
    private var elems: Array[T] = _
    private var capacity: Int = 0
    private var size: Int = 0

    private def mkArray(size: Int): Array[T] = {
      val newelems = new Array[T](size)
      if (this.size > 0) Array.copy(elems, 0, newelems, 0, this.size)
      newelems
    }

    private def resize(size: Int) {
      elems = mkArray(size)
      capacity = size
    }

    override def sizeHint(size: Int) {
      if (capacity < size) resize(size)
    }

If capacity (the current capacity) is smaller than size (the requested size), resize is called. Since b was just created with capacity = 0 and the requested size here is 1, resize executes.

resize in turn calls mkArray, which allocates a new empty array of length size.

If the builder's size is greater than 0, the old elements are copied into the freshly allocated array, and that new array is assigned to elems. On this first call, however, this.size = 0, so mkArray just allocates an empty array of length size.

Finally, resize updates capacity to the new value. At this point capacity = 1 and the builder's size = 0 (the sketch below replays these steps).
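
The same sequence can be replayed with a plain scala.collection.mutable.ArrayBuilder. A minimal sketch (capacity and size are private, so the comments restate what the source above does rather than anything directly observable):

  import scala.collection.mutable.ArrayBuilder

  val b = ArrayBuilder.make[Int]
  b.sizeHint(1)          // capacity: 0 -> 1 via resize/mkArray; size stays 0
  b += 1                 // ensureSize(1) is a no-op; elems(0) = 1; size = 1
  val arr = b.result()   // capacity == size, so the internal array is returned
  println(arr.mkString(","))  // 1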
Now let's return to the slice method:

  def slice(from: Int, until: Int): Repr = {
    val lo    = math.max(from, 0)
    val hi    = math.min(math.max(until, 0), length)
    val elems = math.max(hi - lo, 0)
    val b     = newBuilder
    b.sizeHint(elems)

    var i = lo
    while (i < hi) {
      b += self(i)
      i += 1
    }
    b.result()
  }

Execution continues: i = 0, and self(i), i.e. the element 1, is appended to b. Here is the += source:

    def +=(elem: T): this.type = {
      ensureSize(size + 1)
      elems(size) = elem
      size += 1
      this
    }

So the incoming elem is stored into elems, but ensureSize(size + 1) runs first:

    private def ensureSize(size: Int) {
      if (capacity < size || capacity == 0) {
        var newsize = if (capacity == 0) 16 else capacity * 2
        while (newsize < size) newsize *= 2
        resize(newsize)
      }
    }

Here if (capacity < size || capacity == 0) does not hold (capacity is 1 and the requested size is 1), so this branch is skipped.

Then elems(0) = 1 and size += 1 execute; now capacity = 1 and size = 1.

After i += 1, i is no longer less than hi, so the while loop exits.
Finally, result() is called; here is its source:

    def result() = {
      if (capacity != 0 && capacity == size) elems
      else mkArray(size)
    }

So when the builder's capacity is nonzero and equals size, elems is returned directly; here that means an array containing just 1.

The second call is slice(1, 2), which is essentially the same as (0, 1) since the slice length is again 1; the third call, slice(2, 4), follows the same process. As the parameter names from and until suggest, the interval is closed on the left and open on the right, and both are element indices into the array.
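
Putting it all together, the partitioning logic for the default case can be reproduced in a few lines of plain Scala. A minimal sketch that mirrors positions plus array.slice from the sources above (slicePositions is a hypothetical name for this write-up):

  def slicePositions(length: Int, numSlices: Int): Seq[(Int, Int)] =
    (0 until numSlices).map { i =>
      (((i * length.toLong) / numSlices).toInt,
       (((i + 1) * length.toLong) / numSlices).toInt)
    }

  val data = Array(1, 2, 3, 4)
  val partitions = slicePositions(data.length, 3).map { case (start, end) =>
    data.slice(start, end).toList
  }
  println(partitions)  // Vector(List(1), List(2), List(3, 4))

This matches the output files we saw at the start: partition 0 holds 1, partition 1 holds 2, and partition 2 holds 3 and 4.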
