Spark SQL Source Code Analysis: From Physical Plan to RDD

Following up on the previous article, Spark SQL Catalyst Source Code Analysis: Physical Plan, this article covers the implementation details of a Physical Plan's toRdd:

As we all know, a SQL statement is not actually executed until you call collect() on it; only then is a Spark job submitted and the underlying RDD computed.

  lazy val toRdd: RDD[Row] = executedPlan.execute()
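
As a quick illustration of this laziness, here is a minimal spark-shell sketch against the Spark SQL 1.x API. The Person class and the people table are made up for the example and are not part of the source being discussed:

    // In spark-shell, where sc is the provided SparkContext; purely illustrative.
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.createSchemaRDD              // implicit RDD[Product] => SchemaRDD conversion

    case class Person(name: String, age: Int)      // assumed sample schema
    val people = sc.parallelize(Seq(Person("a", 20), Person("b", 15)))
    people.registerTempTable("people")             // registerAsTable on the earliest 1.0.x releases

    val query = sqlContext.sql("SELECT name FROM people WHERE age > 18")
    println(query.queryExecution)                  // plans are built, but no Spark job has run yet
    query.collect()                                // goes through toRdd = executedPlan.execute() and runs the job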

Spark Plan operators basically fall into four categories: the BasicOperator family, plus the somewhat more complex Join, Aggregate, and Sort operators.

As shown in the class diagram below (figure not reproduced in this text version):

  

1. BasicOperator

1.1 Project

Roughly speaking, Project takes a sequence of expressions (Seq[NamedExpression]) and, given an input Row, evaluates them (Expression.eval) to produce a new Row.

Project is implemented by calling child.execute() and then applying mapPartitions to every partition.
The function passed to mapPartitions creates a MutableProjection and maps each row of the partition through it.

 
    case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
      override def output = projectList.map(_.toAttribute)

      override def execute() = child.execute().mapPartitions { iter => // map the projection over every partition
        @transient val reusableProjection = new MutableProjection(projectList)
        iter.map(reusableProjection)
      }
    }

Looking at the definition of MutableProjection, it is exactly the combination of binding references to a schema and evaluating the expressions:

it converts one Row into another Row whose schema columns are already defined.
If the input Row already has a schema, the given Seq[Expression] is bound against that schema as well.

 
    case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
      def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
        this(expressions.map(BindReferences.bindReference(_, inputSchema))) // bind the expressions to the input schema

      private[this] val exprArray = expressions.toArray
      private[this] val mutableRow = new GenericMutableRow(exprArray.size) // the (reused) output Row
      def currentValue: Row = mutableRow

      def apply(input: Row): Row = {
        var i = 0
        while (i < exprArray.length) {
          mutableRow(i) = exprArray(i).eval(input) // evaluate each expression against the input Row
          i += 1
        }
        mutableRow // return the new Row
      }
    }
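
Conceptually a projection is just "evaluate every expression against the input row and pack the results into a new row". The following plain-Scala sketch shows that shape with rows modeled as Seq[Any] and expressions as functions; none of these names come from Spark, and the mutable-row reuse is deliberately left out:

    object ProjectionSketch {
      type SimpleRow  = Seq[Any]
      type SimpleExpr = SimpleRow => Any            // stand-in for a bound Expression

      // The same loop MutableProjection.apply runs, minus the mutable row reuse.
      def project(exprs: Seq[SimpleExpr])(input: SimpleRow): SimpleRow =
        exprs.map(e => e(input))

      def main(args: Array[String]): Unit = {
        val exprs = Seq[SimpleExpr](r => r(0), r => r(1).asInstanceOf[Int] + 1) // name, age + 1
        val rows  = Iterator(Seq("alice", 30), Seq("bob", 25))
        rows.map(project(exprs)).foreach(println)   // List(alice, 31), List(bob, 26)
      }
    }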

1.2 Filter

Filter is implemented by evaluating the given condition expression against each input row; the result is a Boolean.

If the expression evaluates to true, the row is kept in that partition; otherwise it is filtered out.

 
    case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
      override def output = child.output

      override def execute() = child.execute().mapPartitions { iter =>
        iter.filter(condition.eval(_).asInstanceOf[Boolean]) // eval(input row) and keep the rows that evaluate to true
      }
    }

1.3 Sample

The Sample operator simply takes the RDD returned by child.execute() and calls the RDD's native sample method on it.

 
    case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
      extends UnaryNode
    {
      override def output = child.output

      // TODO: How to pick seed?
      override def execute() = child.execute().sample(withReplacement, fraction, seed)
    }
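
For reference, the native call in isolation (a spark-shell sketch; the numbers are arbitrary):

    // In spark-shell; sample(withReplacement, fraction, seed) is the RDD method Sample delegates to.
    val rdd = sc.parallelize(1 to 1000)
    val sampled = rdd.sample(withReplacement = false, fraction = 0.1, seed = 42L)
    println(sampled.count())                       // roughly 100 elements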

1.4 Union

Union supports unioning multiple sub-queries, so its children argument is a Seq[SparkPlan].

Its execute() calls execute() on every child, obtaining the result RDD of each sub-query.

SparkContext's union method is then used to merge all of the sub-query results.

 
    case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
      // TODO: attributes output by union should be distinct for nullability purposes
      override def output = children.head.output
      override def execute() = sqlContext.sparkContext.union(children.map(_.execute())) // union the result RDDs of all children

      override def otherCopyArgs = sqlContext :: Nil
    }
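
And the SparkContext.union call it relies on, in isolation (spark-shell sketch):

    // In spark-shell; SparkContext.union concatenates the partitions of all input RDDs.
    val a = sc.parallelize(Seq(1, 2, 3))
    val b = sc.parallelize(Seq(4, 5))
    println(sc.union(Seq(a, b)).collect().toSeq)   // 1, 2, 3, 4, 5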

1.5 Limit

The Limit operation also has a counterpart in the native RDD API, namely take().

However, Limit is implemented for two distinct cases:

First, when limit is the final operator (select xxx from yyy limit zzz) and the result is fetched through executeCollect, take is called directly on the driver.

Second, when limit is not the final operator (more operators follow it), each partition first applies the limit locally, and the results are then repartitioned into a single partition to compute the global limit.

 
    case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
      extends UnaryNode {
      // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
      // partition local limit -> exchange into one partition -> partition local limit again

      override def otherCopyArgs = sqlContext :: Nil

      override def output = child.output

      override def executeCollect() = child.execute().map(_.copy()).take(limit) // take directly on the driver

      override def execute() = {
        val rdd = child.execute().mapPartitions { iter =>
          val mutablePair = new MutablePair[Boolean, Row]()
          iter.take(limit).map(row => mutablePair.update(false, row)) // apply the limit locally on each partition first
        }
        val part = new HashPartitioner(1)
        val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part) // shuffle to repartition everything into a single partition
        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
        shuffled.mapPartitions(_.take(limit).map(_._2)) // finally take the limit within that single partition
      }
    }
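
The second code path can be mimicked with plain RDD operations. A rough spark-shell sketch of the same "local limit, then shuffle to one partition, then global limit" idea follows; it is a simplification, not the operator's actual code, which uses a MutablePair plus a ShuffledRDD to avoid extra allocations:

    // In spark-shell; without an ORDER BY, which 5 elements survive is arbitrary, just as with SQL LIMIT.
    val data   = sc.parallelize(1 to 100, numSlices = 4)
    val n      = 5
    val local  = data.mapPartitions(_.take(n))                  // phase 1: at most n rows per partition
    val global = local.repartition(1).mapPartitions(_.take(n))  // phase 2: global limit in a single partition
    println(global.collect().toSeq)                             // 5 elements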

1.6 TakeOrdered

TakeOrdered is a limit N applied after sorting; it is typically used for a limit that follows a sort/order-by.

It can simply be thought of as a Top-N operator.

 
    case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
                          (@transient sqlContext: SQLContext) extends UnaryNode {
      override def otherCopyArgs = sqlContext :: Nil

      override def output = child.output

      @transient
      lazy val ordering = new RowOrdering(sortOrder) // the ordering is implemented via RowOrdering

      override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering)

      // TODO: Terminal split should be implemented differently from non-terminal split.
      // TODO: Pick num splits based on |limit|.
      override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
    }
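
RDD.takeOrdered(n)(ordering) is the primitive executeCollect relies on. A spark-shell sketch with an ad-hoc Ordering standing in for RowOrdering (the Person class and ordering are made up):

    // In spark-shell; a Top-N over a custom Ordering, computed distributedly and returned to the driver.
    case class Person(name: String, age: Int)
    val people = sc.parallelize(Seq(Person("a", 42), Person("b", 17), Person("c", 29)))
    val byAge: Ordering[Person] = Ordering.by(_.age)
    println(people.takeOrdered(2)(byAge).toSeq)    // the two youngest: b (17), c (29)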

1.7 Sort

Sort also relies on the RowOrdering class for its ordering. child.execute() is mapped over each partition, and each partition's rows are sorted according to that RowOrdering, producing a new, ordered collection.

The sort itself is done by materializing each partition into an array and calling Scala's sorted method with that ordering.

 
    case class Sort(
        sortOrder: Seq[SortOrder],
        global: Boolean,
        child: SparkPlan)
      extends UnaryNode {
      override def requiredChildDistribution =
        if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

      @transient
      lazy val ordering = new RowOrdering(sortOrder) // the sort order

      override def execute() = attachTree(this, "sort") {
        // TODO: Optimize sorting operation?
        child.execute()
          .mapPartitions(
            iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, // each partition is sorted with sorted(ordering)
            preservesPartitioning = true)
      }

      override def output = child.output
    }
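
The per-partition pattern, in isolation (spark-shell sketch, with Ordering.Int standing in for RowOrdering; for a global sort, the required OrderedDistribution means the data has already been re-partitioned upstream):

    // In spark-shell; each partition is materialized, sorted, and re-exposed as an iterator.
    val rdd = sc.parallelize(Seq(5, 3, 9, 1, 7), numSlices = 2)
    val sortedWithinPartitions = rdd.mapPartitions(
      it => it.toArray.sorted(Ordering.Int).iterator,
      preservesPartitioning = true)
    println(sortedWithinPartitions.glom().collect().map(_.toSeq).toSeq) // every partition is sorted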

1.8 ExistingRdd

ExistingRdd wraps an RDD of ordinary Scala objects as a relation of Rows: convertToCatalyst converts a Scala value (Option, Seq, Product, or primitive) into a Catalyst-compatible value, productToRowRdd turns an RDD of Products (for example case classes) into an RDD[Row], and fromProductRdd derives the attributes via ScalaReflection and builds the ExistingRdd node:

 
    object ExistingRdd {
      // Convert a Scala value into a Catalyst-compatible value.
      def convertToCatalyst(a: Any): Any = a match {
        case o: Option[_] => o.orNull
        case s: Seq[Any] => s.map(convertToCatalyst)
        case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
        case other => other
      }

      // Turn an RDD of Products (e.g. case classes) into an RDD[Row], reusing one mutable row per partition.
      def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
        data.mapPartitions { iterator =>
          if (iterator.isEmpty) {
            Iterator.empty
          } else {
            val bufferedIterator = iterator.buffered
            val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)

            bufferedIterator.map { r =>
              var i = 0
              while (i < mutableRow.length) {
                mutableRow(i) = convertToCatalyst(r.productElement(i))
                i += 1
              }

              mutableRow
            }
          }
        }
      }

      // Derive the schema via ScalaReflection and wrap the converted RDD in an ExistingRdd node.
      def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = {
        ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))
      }
    }
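
What productToRowRdd does to a single element can be seen with nothing but the standard Product API. A REPL-style sketch (the Person class is made up):

    // A case class instance is a Product, so its fields can be read positionally.
    case class Person(name: String, age: Int, nickname: Option[String])
    val p = Person("alice", 30, None)
    val raw = (0 until p.productArity).map(p.productElement)   // Vector(alice, 30, None)
    // convertToCatalyst then unwraps Options (o.orNull), recurses into Seqs and
    // nested Products, and leaves primitives as-is: Vector(alice, 30, null).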

 

2. Join Related Operators

  HashJoin:

Before walking through the join-related operators, it is worth understanding the HashJoin trait, defined in the joins.scala file under the execution package.

The main join operators, BroadcastHashJoin, LeftSemiJoinHash, and ShuffledHashJoin, all implement this HashJoin trait.

The main class diagram is shown below (figure not reproduced in this text version):

  

  

The key members of the HashJoin trait are:

buildSide indicates which child is used to build the hash table (left or right); it acts as the reference side of the join.

leftKeys are the join-key expressions of the left child; rightKeys are those of the right child.

left is the left child's physical plan; right is the right child's physical plan.

buildSideKeyGenerator is a Projection that evaluates the build-side key expressions for a given Row.

streamSideKeyGenerator produces a MutableProjection that evaluates the stream-side key expressions for a given Row.

If buildSide is BuildLeft, the build side is the left table, and the right table joined against it is the stream side.

  

The key operation in HashJoin is joinIterators. Put simply, it joins two tables, each viewed as an Iterator[Row].

The approach:

1. Iterate over the build side, compute the build keys, and fill a HashMap in the form (buildKey, Iterator[Row]).

2. Iterate over the streamed side, compute each streamed key, and look it up in the HashMap to find matches for the join.

3. For each match, produce a JoinedRow that stitches the two rows together.

See the annotated source below:

 
    trait HashJoin {
      val leftKeys: Seq[Expression]
      val rightKeys: Seq[Expression]
      val buildSide: BuildSide
      val left: SparkPlan
      val right: SparkPlan

      // Pattern match the physical plans into a pair: BuildLeft gives (left, right), BuildRight gives (right, left).
      lazy val (buildPlan, streamedPlan) = buildSide match {
        case BuildLeft => (left, right)
        case BuildRight => (right, left)
      }

      // Pattern match the key expressions into a pair in the same way.
      lazy val (buildKeys, streamedKeys) = buildSide match {
        case BuildLeft => (leftKeys, rightKeys)
        case BuildRight => (rightKeys, leftKeys)
      }

      def output = left.output ++ right.output

      // Projection that evaluates the build-side key expressions against a Row.
      @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
      // Factory for MutableProjections that evaluate the stream-side key expressions against a Row.
      @transient lazy val streamSideKeyGenerator =
        () => new MutableProjection(streamedKeys, streamedPlan.output)

      // Join the build side's Iterator[Row] with the streamed side's Iterator[Row], returning the joined Iterator[Row].
      def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
        // TODO: Use Spark's HashMap implementation.

        val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() // matching is done through a HashMap
        var currentRow: Row = null

        // Create a mapping of buildKeys -> rows
        while (buildIter.hasNext) { // only the build iterator is consumed here; like word count, but collecting the rows per key instead of summing values
          currentRow = buildIter.next()
          val rowKey = buildSideKeyGenerator(currentRow) // the computed key is used as the HashMap key
          if(!rowKey.anyNull) {
            val existingMatchList = hashTable.get(rowKey)
            val matchList = if (existingMatchList == null) {
              val newMatchList = new ArrayBuffer[Row]()
              hashTable.put(rowKey, newMatchList) // (rowKey, matchedRowList)
              newMatchList
            } else {
              existingMatchList
            }
            matchList += currentRow.copy() // append the current row to its match list
          }
        }

        new Iterator[Row] { // finally, probe the build-side HashMap with each streamed row's key
          private[this] var currentStreamedRow: Row = _
          private[this] var currentHashMatches: ArrayBuffer[Row] = _
          private[this] var currentMatchPosition: Int = -1

          // Mutable per row objects.
          private[this] val joinRow = new JoinedRow

          private[this] val joinKeys = streamSideKeyGenerator()

          override final def hasNext: Boolean =
            (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
            (streamIter.hasNext && fetchNext())

          override final def next() = {
            val ret = buildSide match {
              case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) // build side is the right table: streamed row on the left, matched build row on the right
              case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) // build side is the left table: the other way around
            }
            currentMatchPosition += 1
            ret
          }

          /**
           * Searches the streamed iterator for the next row that has at least one match in hashtable.
           *
           * @return true if the search is successful, and false if the streamed iterator runs out of
           *         tuples.
           */
          private final def fetchNext(): Boolean = {
            currentHashMatches = null
            currentMatchPosition = -1

            while (currentHashMatches == null && streamIter.hasNext) {
              currentStreamedRow = streamIter.next()
              if (!joinKeys(currentStreamedRow).anyNull) {
                currentHashMatches = hashTable.get(joinKeys.currentValue) // probe the build-side hash table with the streamed row's key
              }
            }

            if (currentHashMatches == null) {
              false
            } else {
              currentMatchPosition = 0
              true
            }
          }
        }
      }
    }
The implementation of joinRow, which stitches two Rows together:

In effect, copy() allocates a new Array and merges the values of the two rows into it.

 
    class JoinedRow extends Row {
      private[this] var row1: Row = _
      private[this] var row2: Row = _
      .........
      def copy() = {
        val totalSize = row1.size + row2.size
        val copiedValues = new Array[Any](totalSize)
        var i = 0
        while(i < totalSize) {
          copiedValues(i) = apply(i)
          i += 1
        }
        new GenericRow(copiedValues) // return a new, merged Row
      }

2.1 LeftSemiJoinHash

Left semi join needs little introduction: it is what early Hive versions used in place of IN and EXISTS.

The right table's join keys are put into a HashSet; the left table is then scanned, keeping only the rows whose join key is found in the set.

 
    case class LeftSemiJoinHash(
        leftKeys: Seq[Expression],
        rightKeys: Seq[Expression],
        left: SparkPlan,
        right: SparkPlan) extends BinaryNode with HashJoin {

      val buildSide = BuildRight // the right table is the build (reference) side

      override def requiredChildDistribution =
        ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

      override def output = left.output

      def execute() = {
        buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => // execute the right (build) plan, zip its partitions with the streamed plan's, then apply the approach described above
          val hashSet = new java.util.HashSet[Row]()
          var currentRow: Row = null

          // Create a Hash set of buildKeys
          while (buildIter.hasNext) {
            currentRow = buildIter.next()
            val rowKey = buildSideKeyGenerator(currentRow)
            if(!rowKey.anyNull) {
              val keyExists = hashSet.contains(rowKey)
              if (!keyExists) {
                hashSet.add(rowKey)
              }
            }
          }

          val joinKeys = streamSideKeyGenerator()
          streamIter.filter(current => {
            !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
          })
        }
      }
    }
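
Continuing the toy-collection style of the hash-join sketch above (again, not the operator's code), the semi-join logic reduces to a set membership test, and no columns from the build side ever appear in the output:

    object LeftSemiJoinSketch {
      // Keep each probe-side element at most once, if its key appears on the build side.
      def leftSemiJoin[K, B](buildKeys: Iterator[K], probe: Iterator[(K, B)]): Iterator[(K, B)] = {
        val keys = buildKeys.toSet                       // the HashSet of build-side keys
        probe.filter { case (k, _) => keys.contains(k) } // nothing from the build side is emitted
      }

      def main(args: Array[String]): Unit = {
        val rightKeys = Iterator(1, 2, 2)
        val leftRows  = Iterator(1 -> "x", 3 -> "y", 2 -> "z")
        leftSemiJoin(rightKeys, leftRows).foreach(println)   // (1,x), (2,z)
      }
    }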

2.2 BroadcastHashJoin

As the name says: a broadcast hash join.

It implements the inner hash join case. It uses a concurrent future to asynchronously broadcast the collected result of executing the buildPlan.

Once the broadcast value is available, the streamedPlan is matched against that broadcast table.

The execution maps HashJoin's joinIterators over each partition of the streamed RDD to produce the joined result.

 
    case class BroadcastHashJoin(
        leftKeys: Seq[Expression],
        rightKeys: Seq[Expression],
        buildSide: BuildSide,
        left: SparkPlan,
        right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin {

      override def otherCopyArgs = sqlContext :: Nil

      override def outputPartitioning: Partitioning = left.outputPartitioning

      override def requiredChildDistribution =
        UnspecifiedDistribution :: UnspecifiedDistribution :: Nil

      @transient
      lazy val broadcastFuture = future { // broadcast the collected build-side table through SparkContext
        sqlContext.sparkContext.broadcast(buildPlan.executeCollect())
      }

      def execute() = {
        val broadcastRelation = Await.result(broadcastFuture, 5.minute)

        streamedPlan.execute().mapPartitions { streamedIter =>
          joinIterators(broadcastRelation.value.iterator, streamedIter) // call joinIterators on each streamed partition
        }
      }
    }
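
The broadcast-then-map-side-join pattern, in plain RDD terms (a simplified spark-shell stand-in for what BroadcastHashJoin does with Rows; the data and names are made up):

    // In spark-shell; the small "build" table is broadcast, the large "streamed" table is mapped.
    val small = Map(1 -> "a", 2 -> "b")
    val smallBc = sc.broadcast(small)
    val big = sc.parallelize(Seq((1, "x"), (2, "y"), (3, "z")))
    val joined = big.mapPartitions { iter =>
      val lookup = smallBc.value                     // read the broadcast value on the executor
      iter.flatMap { case (k, v) => lookup.get(k).map(s => (k, s, v)) }
    }
    println(joined.collect().toSeq)                  // (1,a,x), (2,b,y)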

2.3 ShuffledHashJoin

ShuffledHashJoin, as the name implies, requires shuffling the data. Its outputPartitioning is the left child's partitioning.

The data is shuffled according to that partitioning, and the RDD zipPartitions method is then used to zip the two sides partition by partition.

The requiredChildDistribution here is ClusteredDistribution, which is matched against HashPartitioning.

The partitioning machinery itself is not covered here; see partitioning under org.apache.spark.sql.catalyst.plans.physical for the details.

 
    case class ShuffledHashJoin(
        leftKeys: Seq[Expression],
        rightKeys: Seq[Expression],
        buildSide: BuildSide,
        left: SparkPlan,
        right: SparkPlan) extends BinaryNode with HashJoin {

      override def outputPartitioning: Partitioning = left.outputPartitioning

      override def requiredChildDistribution =
        ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

      def execute() = {
        buildPlan.execute().zipPartitions(streamedPlan.execute()) { // zip the co-partitioned sides and hash-join each partition pair
          (buildIter, streamIter) => joinIterators(buildIter, streamIter)
        }
      }
    }
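
RDD.zipPartitions pairs up the i-th partition of each RDD, which is what ShuffledHashJoin relies on after both children have been hash-partitioned the same way. A spark-shell sketch where the two RDDs are simply co-partitioned by construction:

    // In spark-shell; zipPartitions hands matching partition iterators to the given function.
    val a = sc.parallelize(Seq(1, 2, 3, 4), numSlices = 2)
    val b = sc.parallelize(Seq("a", "b", "c", "d"), numSlices = 2)
    val zipped = a.zipPartitions(b) { (ia, ib) => ia.zip(ib) }
    println(zipped.collect().toSeq)   // (1,a), (2,b), (3,c), (4,d)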
