1. RDD is an abstract class; subclasses must implement two abstract methods
abstract class RDD[T: ClassTag](
    @transient private var _sc: SparkContext,
    @transient private var deps: Seq[Dependency[_]]
  ) extends Serializable with Logging {
  ...
  // Takes a partition and returns an iterator over that partition's elements of type T
  def compute(split: Partition, context: TaskContext): Iterator[T]
  // Returns all partitions of this RDD
  protected def getPartitions: Array[Partition]
  // Optional partitioner for this RDD
  val partitioner: Option[Partitioner] = None
  ...
}
1.1 RDD partitions
trait Partition extends Serializable {
  def index: Int
  override def hashCode(): Int = index
  override def equals(other: Any): Boolean = super.equals(other)
}
The main piece of data a partition carries is its index.
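1.2 RDD partitioners
abstract class Partitioner extends Serializable {
  def numPartitions: Int
  def getPartition(key: Any): Int
}
The main piece of data is the number of partitions, and the abstract method is getPartition(key). Spark ships with two partitioners: HashPartitioner and RangePartitioner. A partitioner guarantees that records with the same key end up in the same partition, and therefore on the same node. HashPartitioner computes the key's hash code modulo the configured number of partitions, adjusted so the result is non-negative. RangePartitioner is used for sorting. For example, sortByKey builds a RangePartitioner, which samples the parent RDD and therefore triggers a job, and the sort itself is a shuffle: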
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
    : RDD[(K, V)] = self.withScope
{
  val part = new RangePartitioner(numPartitions, self, ascending)
  new ShuffledRDD[K, V, V](self, part)
    .setKeyOrdering(if (ascending) ordering else ordering.reverse)
}
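The basic idea of RangePartitioner: sample every partition of the parent RDD, re-sample the partitions that turn out to be skewed, sort the sampled keys, and then derive the range boundaries from the sorted sample according to the total number of partitions.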
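The two partitioning strategies described above can be sketched in plain Scala. This is a minimal illustration, not Spark's implementation: the object PartitionerSketch and its methods hashPartition, rangeBounds and rangePartition are hypothetical names, keys are restricted to Int for the range case, and the skew-driven re-sampling step is left out.

```scala
// Simplified stand-ins for illustration only; these are NOT Spark's
// HashPartitioner / RangePartitioner classes.
object PartitionerSketch {

  // Hash partitioning: key.hashCode modulo numPartitions, shifted so the
  // result is never negative (the % operator can return negative values).
  def hashPartition(key: Any, numPartitions: Int): Int = {
    if (key == null) 0
    else {
      val raw = key.hashCode % numPartitions
      if (raw < 0) raw + numPartitions else raw
    }
  }

  // Range partitioning, step 1: sort the sampled keys and pick up to
  // numPartitions - 1 boundaries at evenly spaced positions.
  def rangeBounds(sample: Seq[Int], numPartitions: Int): Vector[Int] = {
    val sorted = sample.sorted.toVector
    if (sorted.isEmpty || numPartitions <= 1) Vector.empty
    else (1 until numPartitions).map { i =>
      val pos = math.min(i * sorted.length / numPartitions, sorted.length - 1)
      sorted(pos)
    }.distinct.toVector
  }

  // Range partitioning, step 2: a key belongs to the first range whose upper
  // bound is >= key; keys above every bound go to the last partition.
  def rangePartition(key: Int, bounds: Seq[Int]): Int = {
    val idx = bounds.indexWhere(b => key <= b)
    if (idx >= 0) idx else bounds.length
  }
}
```

With this sketch, equal keys always map to the same partition in both strategies, and range partitioning additionally keeps partitions ordered relative to each other, which is what makes a global sort possible after sorting within each partition.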
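1.3 RDD dependencies
Dependencies are either wide or narrow. Both kinds extend the base class Dependency:
abstract class Dependency[T] extends Serializable {
  def rdd: RDD[T]
}
Wide dependency: each partition of the parent RDD may be used by multiple partitions of the child RDD. In Spark this is implemented by ShuffleDependency.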
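Narrow dependency: each partition of the parent RDD is used by at most one partition of the child RDD.
abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
  // Given a child partition ID, returns the parent partition IDs it depends on
  def getParents(partitionId: Int): Seq[Int]
  // The parent RDD
  override def rdd: RDD[T] = _rdd
}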
There are three kinds of narrow dependency: OneToOneDependency, RangeDependency and PruneDependency.
OneToOneDependency: a one-to-one dependency in which the child partition ID equals the parent partition ID.
class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
  override def getParents(partitionId: Int): List[Int] = List(partitionId)
}
RangeDependency: produced by the union operator. A single child RDD holds several RangeDependency objects, one per parent RDD.
class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
  extends NarrowDependency[T](rdd) {
  override def getParents(partitionId: Int): List[Int] = {
    if (partitionId >= outStart && partitionId < outStart + length) {
      List(partitionId - outStart + inStart) // the matching parent partition ID
    } else {
      Nil
    }
  }
}
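To make the partition mapping concrete, here is a small sketch that mirrors the getParents logic shown above without depending on Spark. The names DependencySketch, RangeDep, depOnA and depOnB are hypothetical, and the scenario (a union of a 3-partition parent A and a 2-partition parent B) is an assumed example.

```scala
// Hypothetical stand-ins that copy only the getParents arithmetic;
// no Spark classes are involved.
object DependencySketch {

  // One-to-one: child partition i reads parent partition i.
  def oneToOneParents(partitionId: Int): List[Int] = List(partitionId)

  // Range: child partitions [outStart, outStart + length) map onto
  // parent partitions starting at inStart.
  final case class RangeDep(inStart: Int, outStart: Int, length: Int) {
    def getParents(partitionId: Int): List[Int] =
      if (partitionId >= outStart && partitionId < outStart + length)
        List(partitionId - outStart + inStart)
      else
        Nil
  }

  // Assumed example: union of parent A (3 partitions) and parent B (2 partitions).
  // The child has 5 partitions: 0-2 come from A, 3-4 come from B.
  val depOnA: RangeDep = RangeDep(inStart = 0, outStart = 0, length = 3)
  val depOnB: RangeDep = RangeDep(inStart = 0, outStart = 3, length = 2)
}
```

Each child partition gets a non-empty answer from exactly one of the two dependencies, which is why the union stays a narrow dependency and needs no shuffle.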
2. How ParallelCollectionRDD works
sc.parallelize() actually creates a ParallelCollectionRDD:
def parallelize[T: ClassTag](
    seq: Seq[T],
    numSlices: Int = defaultParallelism): RDD[T] = withScope {
  assertNotStopped()
  // The last argument maps each partition index to its preferred locations
  new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
private[spark] class ParallelCollectionRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val data: Seq[T],
    numSlices: Int,
    locationPrefs: Map[Int, Seq[String]])
  extends RDD[T](sc, Nil) {
  override def getPartitions: Array[Partition] = {
    val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  }
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
  }
  override def getPreferredLocations(s: Partition): Seq[String] = {
    locationPrefs.getOrElse(s.index, Nil)
  }
}
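ParallelCollectionRDD extends RDD. Because it is a source RDD with no parents, it calls the parent constructor as RDD[T](sc, Nil). Its constructor parameters are sc, data (the input sequence), numSlices and locationPrefs, and it must implement the two abstract methods compute() and getPartitions().
compute(): casts s to ParallelCollectionPartition[T] and takes its iterator, then wraps that iterator together with the TaskContext in an InterruptibleIterator, which serves as the iterator for partition s.
getPartitions(): splits the collection data into numSlices slices and wraps each slice in a ParallelCollectionPartition object.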
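The slicing step in getPartitions() can be sketched as follows. This is a simplified stand-in for ParallelCollectionRDD.slice, assuming a generic Seq and contiguous, nearly equal slices; the object name SliceSketch is hypothetical and the real method may treat some collection types specially.

```scala
// Simplified, hypothetical stand-in for the slicing step; not Spark's own code.
object SliceSketch {

  // Split data into numSlices contiguous slices of nearly equal size.
  // Slice i covers the index range [i * n / numSlices, (i + 1) * n / numSlices).
  def slice[T](data: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    require(numSlices >= 1, "numSlices must be positive")
    val length = data.length.toLong // Long arithmetic avoids Int overflow
    (0 until numSlices).map { i =>
      val start = ((i * length) / numSlices).toInt
      val end = (((i + 1) * length) / numSlices).toInt
      data.slice(start, end)
    }
  }
}
```

Computing both boundaries from the same formula guarantees that the slices are adjacent and cover every element exactly once, even when the length is not divisible by numSlices or when there are more slices than elements (some slices are then empty).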
private[spark] class ParallelCollectionPartition[T: ClassTag](
    var rddId: Long,
    var slice: Int,
    var values: Seq[T]
  ) extends Partition with Serializable {
  def iterator: Iterator[T] = values.iterator
  override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
  override def equals(other: Any): Boolean = other match {
    case that: ParallelCollectionPartition[_] =>
      this.rddId == that.rddId && this.slice == that.slice
    case _ => false
  }
  override def index: Int = slice
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    ...
  }
  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    ...
  }
}
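The constructor parameters of ParallelCollectionPartition are rddId (the RDD ID), slice (the partition index) and values (the partition's data). writeObject(out: ObjectOutputStream) is the serialization hook and readObject(in: ObjectInputStream) is the deserialization hook.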
3. MapPartitionsRDD
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false)
  extends RDD[U](prev) {
  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
  override def getPartitions: Array[Partition] = firstParent[T].partitions
  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))
  override def clearDependencies() {
    super.clearDependencies()
    prev = null
  }
}
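The constructor parameters of MapPartitionsRDD are prev (the parent RDD), f (the processing function, which takes the TaskContext, the partition index and the parent partition's iterator) and preservesPartitioning. compute() calls f once per partition, passing it the parent partition's iterator; f returns a new iterator over the transformed elements, and those iterators together make up the new RDD. An element-wise operator such as map is expressed by having f map over that iterator.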
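The relationship between an element-wise function and the per-partition function f can be sketched without Spark. Everything here is an assumption made for illustration: FakeContext is a hypothetical placeholder for TaskContext, and MapPartitionsSketch, PartitionFn, mapAsPartitionFn and compute are invented names that only mimic the shape of the code above.

```scala
// Hypothetical, Spark-free sketch of how compute() hands a whole partition to f.
object MapPartitionsSketch {

  // Placeholder for TaskContext, present only so the signature lines up.
  final case class FakeContext(stageId: Int)

  // Same shape as the f parameter: (context, partition index, parent iterator) => new iterator.
  type PartitionFn[T, U] = (FakeContext, Int, Iterator[T]) => Iterator[U]

  // An element-wise map expressed as a per-partition function.
  def mapAsPartitionFn[T, U](g: T => U): PartitionFn[T, U] =
    (ctx, idx, iter) => iter.map(g)

  // Mimics compute(): f is called once for the partition and receives the
  // parent's iterator; elements are transformed lazily as the result is consumed.
  def compute[T, U](f: PartitionFn[T, U], index: Int, parent: Iterator[T]): Iterator[U] =
    f(FakeContext(stageId = 0), index, parent)
}
```

Because f works on iterators, chained transformations are pipelined: each element flows through all the functions before the next element is read, and no intermediate collection is materialized.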
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import scala.io.Source
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Partition of LocalFileRdd covering the lines [startLine, endLine)
class LocalFileRddPartition(val split: Int, val startLine: Int, val endLine: Int) extends Partition {
  override def index: Int = split
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
  }
  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
  }
}
// Custom source RDD that reads a local file into (lineNumber, line) pairs
class LocalFileRdd(sc: SparkContext, fileFullName: String, partitionsNum: Int = 4) extends RDD[(Int, String)](sc, Nil)
{
  val lines = Source.fromFile(fileFullName)
    .getLines().toSeq
    .zipWithIndex.map(x => (x._2, x._1))
  override def getPartitions: Array[Partition] = {
    val linesPerPartition = (lines.length + partitionsNum - 1) / partitionsNum // ceiling division
    // `until` rather than `to`, so that exactly partitionsNum partitions are created
    (0 until partitionsNum).map {
      split =>
        new LocalFileRddPartition(split, split * linesPerPartition, (split + 1) * linesPerPartition)
    }.toArray
  }
  override def compute(s: Partition, context: TaskContext): Iterator[(Int, String)] = {
    val partition = s.asInstanceOf[LocalFileRddPartition]
    lines.iterator.filter(x => x._1 >= partition.startLine && x._1 < partition.endLine)
  }
}
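The line-range arithmetic used by getPartitions and compute can be checked in isolation. The sketch below is Spark-free and uses hypothetical names (LineRangeSketch, ranges, linesFor); it assumes half-open ranges [startLine, endLine) and an exclusive upper bound on the partition index, so that exactly partitionsNum ranges are produced.

```scala
// Hypothetical, Spark-free sketch of the line-range computation.
object LineRangeSketch {

  // Returns (startLine, endLine) for each partition, with endLine exclusive.
  def ranges(totalLines: Int, partitionsNum: Int): Seq[(Int, Int)] = {
    require(partitionsNum >= 1, "partitionsNum must be positive")
    // Ceiling division: every partition but possibly the last is full.
    val linesPerPartition = (totalLines + partitionsNum - 1) / partitionsNum
    (0 until partitionsNum).map { split =>
      (split * linesPerPartition, (split + 1) * linesPerPartition)
    }
  }

  // Mirrors compute(): keep the numbered lines that fall inside the range.
  def linesFor(lines: Seq[(Int, String)], range: (Int, Int)): Seq[(Int, String)] =
    lines.filter { case (n, _) => n >= range._1 && n < range._2 }
}
```

For 10 lines and 4 partitions this yields 3 lines per partition, and the last range extends past the end of the file, which is harmless because the filter only keeps line numbers that actually exist.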