Spark RDD Source Code Analysis

1. RDD is an abstract class; a concrete RDD must implement two abstract methods

abstract class RDD[T: ClassTag](
    @transient private var _sc: SparkContext,
    @transient private var deps: Seq[Dependency[_]]
  ) extends Serializable with Logging {

...

    def compute(split: Partition, context: TaskContext): Iterator[T] // given a partition, returns an iterator over every element of type T in that partition

    protected def getPartitions: Array[Partition] // returns all partitions of this RDD

    val partitioner: Option[Partitioner] = None // the partitioner, if any

...

}

1.1 RDD partitions

trait Partition extends Serializable {
  def index: Int
  override def hashCode(): Int = index
  override def equals(other: Any): Boolean = super.equals(other)
}

Its main piece of data is the partition index.

1.2 RDD partitioner

abstract class Partitioner extends Serializable {
  def numPartitions: Int
  def getPartition(key: Any): Int
}
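A minimal sketch of an implementation of this contract, hashing the key and taking a non-negative modulus; this mirrors the idea behind Spark's HashPartitioner described below, but SimpleHashPartitioner is an illustrative name, not a Spark class.

import org.apache.spark.Partitioner

class SimpleHashPartitioner(partitions: Int) extends Partitioner {
  require(partitions > 0, "number of partitions must be positive")
  override def numPartitions: Int = partitions
  override def getPartition(key: Any): Int = key match {
    case null => 0 // send null keys to partition 0, as Spark's HashPartitioner does
    case k =>
      val mod = k.hashCode % partitions
      if (mod < 0) mod + partitions else mod // keep the result in [0, partitions)
  }
}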

The main piece of data is the number of partitions, and the abstract method is getPartition(key). Spark ships with two partitioners: HashPartitioner and RangePartitioner. A partitioner guarantees that all records of the same partition land on the same node. HashPartitioner computes the key's hash and takes it modulo the number of partitions. RangePartitioner is used for sorting data; for example, sortByKey builds one, triggers a Job (for sampling), and involves a shuffle:

def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
    : RDD[(K, V)] = self.withScope
{
  val part = new RangePartitioner(numPartitions, self, ascending)
  new ShuffledRDD[K, V, V](self, part)
    .setKeyOrdering(if (ascending) ordering else ordering.reverse)
}

The basic idea: sample every partition of the parent RDD (re-sampling partitions whose data are skewed), sort the sampled keys, and then derive the range boundaries according to the total number of partitions.
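A usage sketch, assuming a live SparkContext named sc:

val pairs = sc.parallelize(Seq((30, "c"), (10, "a"), (50, "e"), (20, "b"), (40, "d")), 2)
val sorted = pairs.sortByKey(numPartitions = 2)
// each output partition holds a contiguous key range, and the ranges are ordered across partitions
sorted.glom().collect().foreach(part => println(part.map(_._1).mkString(", ")))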

1.3 RDD dependencies

Dependencies are divided into wide and narrow dependencies.

Wide dependency: each partition of the parent RDD is used by multiple partitions of the child RDD.

 

abstract class Dependency[T] extends Serializable {
  def rdd: RDD[T]
}
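For the wide case, shuffle operators such as groupByKey attach a ShuffleDependency, which extends Dependency directly. A small sketch, assuming a live SparkContext named sc:

val pairs = sc.parallelize(Seq((1, "a"), (2, "b"), (1, "c"), (2, "d")), 2)
val grouped = pairs.groupByKey(3)
// every parent partition can feed several of the 3 child partitions
println(grouped.dependencies.head.getClass.getSimpleName) // ShuffleDependency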

Narrow dependency: each partition of the parent RDD is used by at most one partition of the child RDD.

abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
  def getParents(partitionId: Int): Seq[Int] // given a child partition id, returns the ids of the parent partitions it depends on
  override def rdd: RDD[T] = _rdd // the parent RDD
}

There are three kinds of narrow dependencies: OneToOneDependency, RangeDependency, and PruneDependency.

OneToOneDependency: a one-to-one dependency; each child partition depends on the parent partition with the same index.

class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
  override def getParents(partitionId: Int): List[Int] = List(partitionId)
}
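A quick check, assuming a live SparkContext named sc: map keeps the partition layout of its parent.

val nums = sc.parallelize(1 to 10, 4)
val doubled = nums.map(_ * 2)
println(doubled.dependencies.head.getClass.getSimpleName) // OneToOneDependency
println(doubled.partitions.length)                        // 4, same as the parent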

 

RangeDependency: the union operator produces RangeDependency; a single child RDD holds multiple RangeDependency objects, one per parent RDD.

class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
  extends NarrowDependency[T](rdd) {

  override def getParents(partitionId: Int): List[Int] = {
    if (partitionId >= outStart && partitionId < outStart + length) {
      List(partitionId - outStart + inStart) // return the parent partition id
    } else {
      Nil
    }
  }
}
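A sketch of union, assuming a live SparkContext named sc: the child keeps every parent partition and records one RangeDependency per parent.

val a = sc.parallelize(1 to 4, 2)
val b = sc.parallelize(5 to 12, 3)
val u = a.union(b)
println(u.partitions.length)                                   // 5 = 2 + 3
u.dependencies.foreach(d => println(d.getClass.getSimpleName)) // RangeDependency, RangeDependency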

 

2. How ParallelCollectionRDD works

sc.parallelize() actually creates a ParallelCollectionRDD:

def parallelize[T: ClassTag](
    seq: Seq[T],
    numSlices: Int = defaultParallelism): RDD[T] = withScope {
  assertNotStopped()
  new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) // the last argument holds the preferred locations for each partition
}

 

private[spark] class ParallelCollectionRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val data: Seq[T],
    numSlices: Int,
    locationPrefs: Map[Int, Seq[String]])
    extends RDD[T](sc, Nil) {

  override def getPartitions: Array[Partition] = {
    val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  }

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    locationPrefs.getOrElse(s.index, Nil)
  }
}

ParallelCollectionRDD extends RDD. Since it is a source RDD with no parent, it calls the superclass constructor as RDD[T](sc, Nil). Its constructor parameters are sc, the data sequence, numSlices, and locationPrefs, and it implements the two abstract methods compute() and getPartitions().

compute(): casts s to a ParallelCollectionPartition[T], obtains that partition's iterator, and wraps the iterator together with the TaskContext in an InterruptibleIterator.

getPartitions(): splits the collection data into numSlices slices and wraps each slice in a ParallelCollectionPartition object.

 

private[spark] class ParallelCollectionPartition[T: ClassTag](
    var rddId: Long,
    var slice: Int,
    var values: Seq[T]
  ) extends Partition with Serializable {

  def iterator: Iterator[T] = values.iterator

  override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt

  override def equals(other: Any): Boolean = other match {
    case that: ParallelCollectionPartition[_] =>
      this.rddId == that.rddId && this.slice == that.slice
    case _ => false
  }

  override def index: Int = slice

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    ...
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    ...
  }
}

ParallelCollectionPartition's constructor parameters are rddId (the id of the owning RDD), slice (the partition index), and values (the partition's data). writeObject(out: ObjectOutputStream) handles serialization and readObject(in: ObjectInputStream) handles deserialization.
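A usage sketch, assuming a live SparkContext named sc: each partition of a ParallelCollectionRDD is one slice of the local collection.

val rdd = sc.parallelize(1 to 10, 3)
println(rdd.getClass.getSimpleName) // ParallelCollectionRDD
println(rdd.partitions.length)      // 3
rdd.glom().collect().foreach(slice => println(slice.mkString(","))) // 1,2,3 / 4,5,6 / 7,8,9,10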

 

3. MapPartitionsRDD

private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false)
  extends RDD[U](prev) {

  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

  override def getPartitions: Array[Partition] = firstParent[T].partitions

  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))

  override def clearDependencies() {
    super.clearDependencies()
    prev = null
  }
}

MapPartitionsRDD's constructor parameters are prev (the parent RDD) and f (the processing function, which takes a TaskContext, the partition index, and the iterator of the corresponding parent partition). In other words, f is applied to each parent partition's iterator to produce the elements of the corresponding partition of the new RDD.
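A usage sketch, assuming a live SparkContext named sc: map (like mapPartitions) returns a MapPartitionsRDD that reuses its parent's partitions.

val nums = sc.parallelize(1 to 8, 2)
val doubled = nums.map(_ * 2)
println(doubled.getClass.getSimpleName) // MapPartitionsRDD
println(doubled.partitions.length)      // 2, taken from firstParent.partitions
println(doubled.partitioner)            // None, because map does not preserve partitioning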

 

4. A custom RDD example: LocalFileRdd

The following hand-written RDD reads a local text file and exposes (line number, line) pairs; each partition covers a contiguous range of lines.

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import scala.io.Source
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

class LocalFileRddPartition(val split: Int, val startLine: Int, val endLine: Int) extends Partition {
  override def index: Int = split

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
  }
}

 

class LocalFileRdd(sc: SparkContext, fileFullName: String, partitionsNum: Int = 4) extends RDD[(Int, String)](sc, Nil) {

  // read the whole file on the driver and index every line as (line number, line text);
  // the sequence is shipped to executors along with the RDD, which is fine for a small demo
  val lines = Source.fromFile(fileFullName)
    .getLines().toSeq
    .zipWithIndex.map(x => (x._2, x._1))

  override def getPartitions: Array[Partition] = {
    // ceiling division so the last partition also covers any remainder lines
    val linesPerPartition = (lines.length + partitionsNum - 1) / partitionsNum
    (0 until partitionsNum).map { split =>
      new LocalFileRddPartition(split, split * linesPerPartition, (split + 1) * linesPerPartition)
    }.toArray
  }

  override def compute(s: Partition, context: TaskContext): Iterator[(Int, String)] = {
    val partition = s.asInstanceOf[LocalFileRddPartition]
    lines.iterator.filter(x => x._1 >= partition.startLine && x._1 < partition.endLine)
  }
}
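A usage sketch; /tmp/data.txt is a hypothetical path, and sc is assumed to be a live SparkContext running in local mode (the file is read on the driver):

val fileRdd = new LocalFileRdd(sc, "/tmp/data.txt", partitionsNum = 4)
println(fileRdd.partitions.length) // 4
fileRdd.collect().foreach { case (lineNo, line) => println(s"$lineNo: $line") }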

 
