Source Code Analysis of makeRDD in Spark

makeRDD source code walkthrough

// Returns a ParallelCollectionRDD
def makeRDD[T: ClassTag](
      seq: Seq[T],   
      numSlices: Int = defaultParallelism): RDD[T] = withScope {
    parallelize(seq, numSlices)
  }
//The partition count numSlices has a default value (defaultParallelism), which is used when the caller does not pass one
//The code block {parallelize(seq, numSlices)} is passed to withScope as a by-name argument
val rdd1: RDD[Int] = sparkContext.makeRDD(list)
    
val rdd2: RDD[Int] = sparkContext.parallelize(list)

//These two ways of creating an RDD are equivalent, because makeRDD just calls parallelize
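A minimal runnable sketch of that equivalence, assuming a local SparkContext (the master, app name and list values here are only illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object MakeRDDVsParallelize {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("makeRDD-demo")
    val sparkContext = new SparkContext(conf)

    val list = List(1, 2, 3, 4, 5)
    val rdd1: RDD[Int] = sparkContext.makeRDD(list)      // delegates to parallelize
    val rdd2: RDD[Int] = sparkContext.parallelize(list)

    // Same contents, same number of partitions
    println(rdd1.collect().toList == rdd2.collect().toList)   // true
    println(rdd1.getNumPartitions == rdd2.getNumPartitions)   // true

    sparkContext.stop()
  }
}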

makeRDD simply forwards the incoming collection and the partition count to parallelize and hands that call to withScope. Next, let's walk through the parallelize source:

// 1. parallelize source
/*
Distribute a local Scala collection to form an RDD.
*/
def parallelize[T: ClassTag](
      seq: Seq[T],
      numSlices: Int = defaultParallelism): RDD[T] = withScope {
    assertNotStopped()
    new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
  }

//seq is the collection being distributed and numSlices is the partition count, which defaults to defaultParallelism. Let's trace defaultParallelism next:

Click through, and the overriding method is:
override def defaultParallelism(): Int = backend.defaultParallelism()
Click through again to def defaultParallelism(): Int, press Ctrl+H to list its implementing classes, and open LocalSchedulerBackend.
There the method is defined as follows:
override def defaultParallelism(): Int =
    scheduler.conf.getInt("spark.default.parallelism", totalCores)

"spark.default.parallelism" is the default-parallelism configuration key. Its value is read from the conf; if it is not set, totalCores is used instead.

So we can conclude:
If a partition count is passed to makeRDD, that value is the final number of partitions. If not, the partition count falls back to the default: spark.default.parallelism if it is set in the conf, otherwise totalCores. So how big is totalCores? Take local mode as an example:
		setMaster(local):       the default partition count is 1
		setMaster(local[n])     the default partition count is n
		setMaster(local[*])     the default partition count is the number of cores available in the current environment
(for example, on an 8-core/16-thread machine, local[*] gives 16 partitions by default)
When saveAsTextFile is called to write output, each partition produces one file, as the sketch below shows.
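A sketch covering these rules (the master setting, config value, and output path below are only illustrative): an explicit numSlices always wins; otherwise the default comes from spark.default.parallelism if set, else from totalCores; and saveAsTextFile writes one part file per partition:

import org.apache.spark.{SparkConf, SparkContext}

object PartitionCountDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[4]")                      // totalCores = 4 in local mode
      .setAppName("partition-count-demo")
      .set("spark.default.parallelism", "8")      // when set, this wins over totalCores
    val sc = new SparkContext(conf)

    println(sc.defaultParallelism)                       // 8, from the conf (4 if the key were unset)

    val withDefault = sc.makeRDD(1 to 100)               // no numSlices: uses defaultParallelism
    println(withDefault.getNumPartitions)                // 8

    val explicit = sc.makeRDD(List(1, 2, 3, 4, 5), 3)    // explicit numSlices wins
    println(explicit.getNumPartitions)                   // 3

    // Writes part-00000, part-00001, part-00002 -- one file per partition
    explicit.saveAsTextFile("output/partition-count-demo")   // hypothetical output path

    sc.stop()
  }
}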

// 2. parallelize constructs a ParallelCollectionRDD; the ParallelCollectionRDD class is defined as follows
private[spark] class ParallelCollectionRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val data: Seq[T],
    numSlices: Int,
    locationPrefs: Map[Int, Seq[String]])
    extends RDD[T](sc, Nil) {
  // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
  // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
  // instead.
  // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.

  override def getPartitions: Array[Partition] = {
    val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  }

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    locationPrefs.getOrElse(s.index, Nil)
  }
}
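Note that locationPrefs is always empty when coming through parallelize; it is populated by the other makeRDD overload, which takes (element, preferred hosts) pairs. A small sketch (the host names are made up) showing how those preferences come back out through getPreferredLocations:

import org.apache.spark.{SparkConf, SparkContext}

object LocationPrefsDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("location-prefs-demo"))

    // This makeRDD overload builds a ParallelCollectionRDD with one partition per element
    // and a locationPrefs map from partition index to preferred hosts
    val rdd = sc.makeRDD(Seq(
      (1, Seq("host-a")),                 // hypothetical host names
      (2, Seq("host-b", "host-c"))
    ))

    rdd.partitions.foreach { p =>
      println(s"partition ${p.index} prefers ${rdd.preferredLocations(p)}")
    }
    // partition 0 prefers List(host-a)
    // partition 1 prefers List(host-b, host-c)

    sc.stop()
  }
}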

//3. The ParallelCollectionRDD companion object, whose slice method is what getPartitions actually calls:
private object ParallelCollectionRDD {
  /**
   * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
   * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
   * it efficient to run Spark over RDDs representing large sets of numbers. And if the collection
   * is an inclusive Range, we use inclusive range for the last slice.
   */
  def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    // Validate the partition count; asking for fewer than 1 partition throws an exception
    if (numSlices < 1) {
      throw new IllegalArgumentException("Positive number of partitions required")
    }
    // Sequences need to be sliced at the same set of index positions for operations
    // like RDD.zip() to behave as expected
    def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
      (0 until numSlices).iterator.map { i =>
        val start = ((i * length) / numSlices).toInt
        val end = (((i + 1) * length) / numSlices).toInt
        (start, end)
      }
    }
  
    // Pattern match on the incoming Seq to check whether it is a Range
    seq match {
      case r: Range =>
        positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
          // If the range is inclusive, use inclusive range for the last slice
          if (r.isInclusive && index == numSlices - 1) {
            new Range.Inclusive(r.start + start * r.step, r.end, r.step)
          }
          else {
            new Range(r.start + start * r.step, r.start + end * r.step, r.step)
          }
        }.toSeq.asInstanceOf[Seq[Seq[T]]]
      case nr: NumericRange[_] =>
        // For ranges of Long, Double, BigInteger, etc
        val slices = new ArrayBuffer[Seq[T]](numSlices)
        var r = nr
        for ((start, end) <- positions(nr.length, numSlices)) {
          val sliceSize = end - start
          slices += r.take(sliceSize).asInstanceOf[Seq[T]]
          r = r.drop(sliceSize)
        }
        slices
      // Any collection that is not a (Numeric)Range falls through to this case
      case _ =>
        // Convert the collection to an Array first
        val array = seq.toArray // To prevent O(n^2) operations for List etc
       
        // Pass the array length and the partition count to the positions method above; it returns
        // an iterator of (start, end) tuples, one per partition, and each tuple then slices the array
        positions(array.length, numSlices).map { case (start, end) =>
            // array.slice cuts out the half-open interval [from, until)
            array.slice(start, end).toSeq
        }.toSeq
    }
  }
}
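The effect of slice can be observed directly with glom, which returns one array per partition; a sketch assuming a local SparkContext:

import org.apache.spark.{SparkConf, SparkContext}

object SliceDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("slice-demo"))

    // A plain List goes through the `case _` branch: array.slice per (start, end) pair
    val listRDD = sc.makeRDD(List(1, 2, 3, 4, 5), 3)
    listRDD.glom().collect().foreach(part => println(part.mkString("[", ",", "]")))
    // [1]  [2,3]  [4,5]

    // A Range goes through the `case r: Range` branch and is re-encoded as sub-Ranges,
    // which is the memory-saving trick the class comment describes
    val rangeRDD = sc.makeRDD(1 to 10, 3)
    rangeRDD.glom().collect().foreach(part => println(part.mkString("[", ",", "]")))
    // [1,2,3]  [4,5,6]  [7,8,9,10]

    sc.stop()
  }
}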
// 4. The slice method invoked on the array is implemented as follows:
 override def slice(from: Int, until: Int): Array[T] = {
     val lo = math.max(from, 0)
     val hi = math.min(math.max(until, 0), repr.length)
     val size = math.max(hi - lo, 0)
     val result = java.lang.reflect.Array.newInstance(elementClass, size)
     if (size > 0) {
      Array.copy(repr, lo, result, 0, size)
     }
     result.asInstanceOf[Array[T]]
  }
/*
How slice splits the array:
For example, with the partition count set to 3, List(1,2,3,4,5) run through positions becomes Array((0,1),(1,3),(3,5)).
Mapping over that array, each (start, end) tuple calls slice:
(0,1) takes index 0 up to (but not including) index 1, so 1 goes into the first partition;
(1,3) takes indices 1 and 2, so 2 and 3 go into the second partition;
(3,5) takes indices 3 and 4, so 4 and 5 go into the third partition.
*/
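The worked example can be checked without Spark by reproducing the positions arithmetic and array.slice in plain Scala (a standalone sketch; the object name is only for illustration):

object PositionsCheck {
  // Same arithmetic as the positions helper inside ParallelCollectionRDD.slice
  def positions(length: Long, numSlices: Int): Seq[(Int, Int)] =
    (0 until numSlices).map { i =>
      (((i * length) / numSlices).toInt, (((i + 1) * length) / numSlices).toInt)
    }

  def main(args: Array[String]): Unit = {
    val data = Array(1, 2, 3, 4, 5)
    println(positions(data.length, 3))                        // Vector((0,1), (1,3), (3,5))
    println(positions(data.length, 3).map { case (start, end) =>
      data.slice(start, end).toList                           // end is exclusive
    })                                                        // Vector(List(1), List(2, 3), List(4, 5))
  }
}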

Finally, let's look at the withScope source:

// 5. withScope source: no need to dwell on it here; the body that builds the RDD is passed to withScope, which executes it and returns the RDD
/**
   * Execute the given body such that all RDDs created in this body will have the same scope.
   * The name of the scope will be the first method name in the stack trace that is not the same as this method's.
   * Note: Return statements are NOT allowed in body.
   */
  private[spark] def withScope[T](
      sc: SparkContext,
      allowNesting: Boolean = false)(body: => T): T = {
    val ourMethodName = "withScope"
    val callerMethodName = Thread.currentThread.getStackTrace()
      .dropWhile(_.getMethodName != ourMethodName)
      .find(_.getMethodName != ourMethodName)
      .map(_.getMethodName)
      .getOrElse {
        // Log a warning just in case, but this should almost certainly never happen
        logWarning("No valid method name for this RDD operation scope!")
        "N/A"
      }
    withScope[T](sc, callerMethodName, allowNesting, ignoreParent = false)(body)
  }
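withScope is private[spark], so user code never calls it; the pattern worth noticing is simply the by-name body parameter. A minimal sketch of the same wrapping shape (not Spark code, only an illustration of the call pattern):

object WithScopePattern {
  // `body: => T` is a by-name parameter: it is evaluated only when used inside the wrapper,
  // which is how the block {parallelize(seq, numSlices)} runs inside withScope's bookkeeping
  def withLogging[T](name: String)(body: => T): T = {
    println(s"entering scope $name")
    try body
    finally println(s"leaving scope $name")
  }

  def main(args: Array[String]): Unit = {
    val result = withLogging("makeRDD") {
      List(1, 2, 3).map(_ * 2)    // stands in for the RDD-building body
    }
    println(result)               // List(2, 4, 6)
  }
}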

Summary:

makeRDD is parallelize wrapped in withScope. parallelize builds a ParallelCollectionRDD, and its getPartitions splits the local collection into numSlices pieces via ParallelCollectionRDD.slice. The final partition count is the explicit numSlices argument if one is given, otherwise spark.default.parallelism from the conf, otherwise the scheduler backend's totalCores.
