Spark Sql 4/5

4. 用户自定义函数

通过spark.udf功能用户可以自定义函数。

4.1用户自定义UDF函数

Shell
scala> val df = spark.read.json("examples/src/main/resources/people.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]

scala> df.show()
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+


scala> spark.udf.register("addName", (x:String)=> "Name:"+x)
res5: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))

scala> df.createOrReplaceTempView("people")

scala> spark.sql("Select addName(name), age from people").show()
+-----------------+----+
|UDF:addName(name)| age|
+-----------------+----+
|     Name:Michael|null|
|        Name:Andy|  30|
|      Name:Justin|  19|
+-----------------+----+

UDF案例2

需求,有如下数据

Plain Text
id,name,age,height,weight,yanzhi,score
1,a,18,172,120,98,68.8
2,b,28,175,120,97,68.8
3,c,30,180,130,94,88.8
4,d,18,168,110,98,68.8
5,e,26,165,120,98,68.8
6,f,27,182,135,95,89.8
7,g,19,171,122,99,68.8

需要计算每一个人和其他人之间的余弦相似度(特征向量之间的余弦相似度)

 

代码实现:

Scala
package cn.doitedu.sparksql.udf

import cn.doitedu.sparksql.dataframe.SparkUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.UserDefinedFunction

import scala.collection.mutable


/**
  * UDF 案例2 : 用一个自定义函数实现两个向量之间的余弦相似度计算
  */

case class Human(id: Int, name: String, features: Array[Double])

object CosinSimilarity {

  def main(args: Array[String]): Unit = {


    val spark = SparkUtil.getSpark()
    import spark.implicits._
    import spark.sql
    // 加载用户特征数据
    val df = spark.read.option("inferSchema", true).option("header", true).csv("data/features.csv")
    df.show()



    // id,name,age,height,weight,yanzhi,score
    // 将用户特征数据组成一个向量(数组)
    // 方式1:
    df.rdd.map(row => {
      val id = row.getAs[Int]("id")
      val name = row.getAs[String]("name")
      val age = row.getAs[Double]("age")
      val height = row.getAs[Double]("height")
      val weight = row.getAs[Double]("weight")
      val yanzhi = row.getAs[Double]("yanzhi")
      val score = row.getAs[Double]("score")

      (id, name, Array(age, height, weight, yanzhi, score))
    }).toDF("id", "name", "features")

    // 方式2:
    df.rdd.map({
      case Row(id: Int, name: String, age: Double, height: Double, weight: Double, yanzhi: Double, score: Double)
      => (id, name, Array(age, height, weight, yanzhi, score))
    })
      .toDF("id", "name", "features")


    // 方式3: 直接利用sql中的函数array来生成一个数组
    df.selectExpr("id", "name", "array(age,height,weight,yanzhi,score) as features")
    import org.apache.spark.sql.functions._
    df.select('id, 'name, array('age, 'height, 'weight, 'yanzhi, 'score) as "features")

    // 方式4:返回case class
    val features = df.rdd.map({
      case Row(id: Int, name: String, age: Double, height: Double, weight: Double, yanzhi: Double, score: Double)
      => Human(id, name, Array(age, height, weight, yanzhi, score))
    })
      .toDF()

    // 将表自己和自己join,得到每个人和其他所有人的连接行
    val joined = features.join(features.toDF("bid","bname","bfeatures"),'id < 'bid)
    joined.show(100,false)

    // 定义一个计算余弦相似度的函数
    // val cosinSim = (f1:Array[Double],f2:Array[Double])=>{ /* 余弦相似度 */ }
    // 开根号的api:  Math.pow(4.0,0.5)
    val cosinSim = (f1:mutable.WrappedArray[Double], f2:mutable.WrappedArray[Double])=>{

      val fenmu1 = Math.pow(f1.map(Math.pow(_,2)).sum,0.5)
      val fenmu2 = Math.pow(f2.map(Math.pow(_,2)).sum,0.5)

      val fenzi = f1.zip(f2).map(tp=>tp._1*tp._2).sum

      fenzi/(fenmu1*fenmu2)
    }

    // 注册到sql引擎:  spark.udf.register("cosin_sim",consinSim)
    spark.udf.register("cos_sim",cosinSim)
    joined.createTempView("temp")

    // 然后在这个表上计算两人之间的余弦相似度
    sql("select id,bid,cos_sim(features,bfeatures) as cos_similary from temp").show()

    // 可以自定义函数简单包装一下,就成为一个能生成column结果的dsl风格函数了
    val cossim2: UserDefinedFunction = udf(cosinSim)
    joined.select('id,'bid,cossim2('features,'bfeatures) as "cos_sim").show()

    spark.close()
  }
}

4.2用户自定义聚合函数UDAF

弱类型的DataFrame和强类型的Dataset都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。

除此之外,用户可以设定自己的自定义UDAF聚合函数。

UDAF的编程模板:

/**

 * @date: 2019/10/12

 * @site: www.doitedu.cn

 * @author: hunter.d 涛哥

 * @qq: 657270652

 * @description:

  *    用户自定义UDAF入门示例:求薪资的平均值

 */

object MyAvgUDAF extends UserDefinedAggregateFunction{

  // 函数输入的字段schema(字段名-字段类型)

  override def inputSchema: StructType = ???

  // 聚合过程中,用于存储局部聚合结果的schema

  // 比如求平均薪资,中间缓存(局部数据薪资总和,局部数据人数总和)

  override def bufferSchema: StructType = ???

  // 函数的最终返回结果数据类型

  override def dataType: DataType = ???

  // 你这个函数是否是稳定一致的?(对一组相同的输入,永远返回相同的结果),只要是确定的,就写true

  override def deterministic: Boolean = true

  // 对局部聚合缓存的初始化方法

  override def initialize(buffer: MutableAggregationBuffer): Unit = ???

  // 聚合逻辑所在方法,框架会不断地传入一个新的输入row,来更新你的聚合缓存数据

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???

  // 全局聚合:将多个局部缓存中的数据,聚合成一个缓存

  // 比如:薪资和薪资累加,人数和人数累加

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???

  // 最终输出

  // 比如:从全局缓存中取薪资总和/人数总和

  override def evaluate(buffer: Row): Any = ???

}

核心要义:

聚合是分步骤进行: 先局部聚合,再全局聚合

局部聚合(update)的结果是保存在一个局部buffer中的

全局聚合(merge)就是将多个局部buffer再聚合成一个buffer

最后通过evaluate将全局聚合的buffer中的数据做一个运算得出你要的结果

如下图所示:

 

4.2.1弱类型用户自定义聚合函数UDAF

(1)需求说明

示例数据:

+---+----------------+------+---------+------+----------+

| id|    name        | sales|discount |state |  saleDate|

+---+----------------+------+---------+------+----------+

|  1|       Widget Co|1000.0|      0.0|    AZ|2014-01-01|

|  2|   Acme Widgets |2000.0|    500.0|    CA|2014-02-01|

|  3|        Widgetry|1000.0|    200.0|    CA|2015-01-11|

|  4|   Widgets R Us |2000.0|      0.0|    CA|2015-02-19|

|  5|Ye Olde Widgete |3000.0|      0.0|    MA|2015-02-28|

+---+---------------+------+--------+-----+-------------+

需求:计算x年份的同比上一年份的总销售增长率;比如2015 vs 2014的同比增长

显然,没有任何一个内置聚合函数可以完成上述需求;

可以多写一些sql逻辑来实现,但如果能自定义一个聚合函数,当然更方便高效!

Select yearOnyear(saleDate,sales)  from t

(2)自定义UDAF实现销售额同比计算

通过继承UserDefinedAggregateFunction来实现用户自定义聚合函数。

自定义UDAF的代码骨架如下:

class UdfMy extends UserDefinedAggregateFunction{

  override def inputSchema: StructType = ???

  override def bufferSchema: StructType = ???

  override def dataType: DataType = ???

  override def deterministic: Boolean = ???

  override def initialize(buffer: MutableAggregationBuffer): Unit = ???

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???

  override def evaluate(buffer: Row): Any = ???

}

完整实现代码如下:

/**

  * 工具类

  * @param startDate

  * @param endDate

  */

case class DateRange(startDate: Timestamp, endDate: Timestamp) {

  def contain(targetDate: Date): Boolean = {

    targetDate.before(endDate) && targetDate.after(startDate)

  }

}

/**

 * @date: 2019/10/10

 * @site: www.doitedu.cn

 * @author: hunter.d 涛哥

 * @qq: 657270652

 * @description: 自定义UDAF实现年份销售额同比增长计算

 */

class YearOnYearBasis(current: DateRange) extends  UserDefinedAggregateFunction{

  // 聚合函数输入参数的数据类型

  override def inputSchema: StructType = {

    StructType(StructField("metric", DoubleType) :: StructField("timeCategory", DateType) :: Nil)

  }

  // 聚合缓冲区中值得数据类型

  override def bufferSchema: StructType = {

    StructType(StructField("sumOfCurrent", DoubleType) :: StructField("sumOfPrevious", DoubleType) :: Nil)

  }

  // 返回值的数据类型

  override def dataType: DataType = DoubleType

  // 对于相同的输入是否一直返回相同的输出。

  override def deterministic: Boolean = true

  // 初始化

  override def initialize(buffer: MutableAggregationBuffer): Unit = {

    buffer.update(0, 0.0)

    buffer.update(1, 0.0)

  }

  // 相同Execute间的数据合并。

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =  {

    if (current.contain(input.getAs[Date](1))) {

      buffer(0) = buffer.getAs[Double](0) + input.getAs[Double](0)

    }

    val previous = DateRange(subtractOneYear(current.startDate), subtractOneYear(current.endDate))

    if (previous.contain(input.getAs[Date](1))) {

      buffer(1) = buffer.getAs[Double](0) + input.getAs[Double](0)

    }

  }

  // 不同Execute间的数据合并

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)

    buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)

  }

  // 计算最终结果

  override def evaluate(buffer: Row): Any = {

    if (buffer.getDouble(1) == 0.0)

      0.0

    else

      (buffer.getDouble(0) - buffer.getDouble(1)) / buffer.getDouble(1) * 100

  }

  def subtractOneYear(d:Timestamp):Timestamp={

    Timestamp.valueOf(d.toLocalDateTime.minusYears(1))

  }

}

(3)补充示例:自定义UDAF实现平均薪资计算

下面展示一个求平均工资的自定义聚合函数。

Scala
package cn.doitedu.sparksql.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

/**
  * @description:
  * 用户自定义UDAF入门示例:求薪资的平均值
  */
object MyAvgUDAF extends UserDefinedAggregateFunction {

  // 函数输入的字段schema(字段名-字段类型)
  override def inputSchema: StructType = StructType(Seq(StructField("salary", DataTypes.DoubleType)))

  // 聚合过程中,用于存储局部聚合结果的schema
  // 比如求平均薪资,中间缓存(局部数据薪资总和,局部数据人数总和)
  override def bufferSchema: StructType = StructType(Seq(
    StructField("sum", DataTypes.DoubleType),
    StructField("cnts", DataTypes.LongType)

  ))

  // 函数的最终返回结果数据类型
  override def dataType: DataType = DataTypes.DoubleType

  // 你这个函数是否是稳定一致的?(对一组相同的输入,永远返回相同的结果),只要是确定的,就写true
  override def deterministic: Boolean = true

  // 对局部聚合缓存的初始化方法
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0L)
  }

  // 聚合逻辑所在方法,框架会不断地传入一个新的输入row,来更新你的聚合缓存数据
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

    // 从输入中获取那个人的薪资,加到buffer的第一个字段上
    buffer.update(0, buffer.getDouble(0) + input.getDouble(0))

    // 给buffer的第2个字段加1
    buffer.update(1, buffer.getLong(1) + 1)

  }

  // 全局聚合:将多个局部缓存中的数据,聚合成一个缓存
  // 比如:薪资和薪资累加,人数和人数累加
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    // 把两个buffer的字段1(薪资和)累加到一起,并更新回buffer1
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))

    // 更新人数
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))

  }

  // 最终输出
  // 比如:从全局缓存中取薪资总和/人数总和
  override def evaluate(buffer: Row): Any = {

    if (buffer.getLong(1) != 0)
      buffer.getDouble(0) / buffer.getLong(1)
    else
      0.0

  }
}

4.2.2强类型用户自定义聚合函数

通过继承Aggregator来实现强类型自定义聚合函数,同样是求平均工资

Scala
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession

// 既然是强类型,可能有case类
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {
    // 定义一个数据结构,保存工资总数和工资总个数,初始都为0
    def zero: Average = Average(0L, 0L)
    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
    }
    // 聚合不同execute的结果
    def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
    }
    // 计算输出
    def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
    // 设定之间值类型的编码器,要转换成case类
    // Encoders.product是进行scala元组和case类转换的编码器
    def bufferEncoder: Encoder[Average] = Encoders.product
    // 设定最终输出值的编码器
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }
    import spark.implicits._
​    
    val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
    ds.show()
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+
​    
    // Convert the function to a `TypedColumn` and give it a name
    val averageSalary = MyAverage.toColumn.name("average_salary")
    val result = ds.select(averageSalary)
    result.show()
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
}

5. Spark SQL 的运行原理

正常的 SQL 执行先会经过 SQL Parser 解析 SQL,然后经过 Catalyst 优化器处理,最后到 Spark 执行。而 Catalyst 的过程又分为很多个过程,其中包括:

  • Analysis:主要利用 Catalog 信息将 Unresolved Logical Plan 解析成 Analyzed logical plan;
  • Logical Optimizations:利用一些 Rule (规则)将 Analyzed logical plan 解析成 Optimized Logical Plan;
  • Physical Planning:前面的 logical plan 不能被 Spark 执行,而这个过程是把 logical plan 转换成多个 physical plans,然后利用代价模型(cost model)选择最佳的 physical plan;
  • Code Generation:这个过程会把 SQL逻辑生成Java字节码。

所以整个 SQL 的执行过程可以使用下图表示:

 

其中蓝色部分就是 Catalyst 优化器处理的部分,也是本章主要讲解的内容。

5.1 元数据管理SessionCatalog

SessionCatalog 主要用于各种函数资源信息和元数据信息(数据库、数据表、数据视图、数据分区与函数等)的统一管理。

创建临时表或者视图,其实是往SessionCatalog注册;

Analyzer在进行逻辑计划元数据绑定时,也是从catalog中获取元数据;

5.2 SQL解析成逻辑执行计划

当调用SparkSession的sql或者SQLContext的sql方法,就会使用SparkSqlParser进行SQL解析。

Spark 2.0.0开始引入了第三方语法解析器工具 ANTLR,对 SQL 进行词法分析并构建语法树。

(Antlr 是一款强大的语法生成器工具,可用于读取、处理、执行和翻译结构化的文本或二进制文件,是当前 Java 语言中使用最为广泛的语法生成器工具,我们常见的大数据 SQL 解析都用到了这个工具,包括 Hive、Cassandra、Phoenix、Pig 以及 presto 等)目前最新版本的 Spark 使用的是 ANTLR4)

它分为2个步骤来生成Unresolved LogicalPlan:

  • 词法分析(SqlBaseLexer):Lexical Analysis,负责将token分组成符号类
  • 语法分析(SqlBaseParser):构建一棵分析树(parse tree)或者抽象语法树AST(abstract syntax tree)

Scala
/**
 * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
 * TableIdentifier.
 */
class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging {
  import ParserUtils._

  def this() = this(new SQLConf())

  protected def typedVisit[T](ctx: ParseTree): T = {
  
  ...
  }
}

具体来说,Spark 基于presto的语法文件定义了Spark SQL语法文件SqlBase.g4

(路径 spark-2.4.3\sql\catalyst\src\main\antlr4\org\apache\spark\sql\catalyst\parser\SqlBase.g4)

这个文件定义了 Spark SQL 支持的 SQL 语法。

 

如果我们需要自定义新的语法,需要在这个文件定义好相关语法。然后使用 ANTLR4 对 SqlBase.g4 文件自动解析生成几个 Java 类,其中就包含重要的词法分析器 SqlBaseLexer.java 和语法分析器SqlBaseParser.java。运行上面的 SQL 会使用 SqlBaseLexer 来解析关键词以及各种标识符等;然后使用 SqlBaseParser 来构建语法树。

下面以一条简单的 SQL 为例进行分析

SQL
SELECT sum(v)
    FROM (
      SELECT
        t1.id,
    1 + 2 + t1.value AS v
    FROM t1 JOIN t2
      WHERE
    t1.id = t2.id AND
    t1.cid = 1 AND
    t1.did = t1.cid + 1 AND
      t2.id > 5) o

整个过程就类似于下图。

 

生成语法树之后,使用 AstBuilder 将语法树转换成 LogicalPlan,这个 LogicalPlan 也被称为 Unresolved LogicalPlan。解析后的逻辑计划如下:

Plain Text
== Parsed Logical Plan ==
'Project [unresolvedalias('sum('v), None)]
+- 'SubqueryAlias `doitedu_stu`
   +- 'Project ['t1.id, ((1 + 2) + 't1.value) AS v#16]
      +- 'Filter ((('t1.id = 't2.id) && ('t1.cid = 1)) && (('t1.did = ('t1.cid + 1)) && ('t2.id > 5)))
         +- 'Join Inner
            :- 'UnresolvedRelation `t1`
            +- 'UnresolvedRelation `t2`

图片表示如下:

 

Unresolved LogicalPlan 是从下往上看的,t1 和 t2 两张表被生成了 UnresolvedRelation,过滤的条件、选择的列以及聚合字段都知道了。

Unresolved LogicalPlan 仅仅是一种数据结构,不包含任何数据信息,比如不知道数据源、数据类型,不同的列来自于哪张表等。

5.3 Analyzer绑定逻辑计划

Analyzer 阶段会使用事先定义好的 Rule 以及 SessionCatalog 等信息对 Unresolved LogicalPlan 进行元数据绑定。

Scala
/**
 * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
 * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
 */
class Analyzer(
    catalog: SessionCatalog,
    conf: SQLConf,
    maxIterations: Int)
  extends RuleExecutor[LogicalPlan] with CheckAnalysis {


class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser(conf) {
  val astBuilder = new SparkSqlAstBuilder(conf)


   override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
    astBuilder.visitSingleStatement(parser.singleStatement()) match {
      case plan: LogicalPlan => plan
      case _ =>
        val position = Origin(None, None)
        throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
    }
  }

Rule 是定义在 Analyzer 里面的,具体如下:
lazy val batches: Seq[Batch] = Seq(
    Batch("Hints", fixedPoint,
      new ResolveHints.ResolveBroadcastHints(conf),
      ResolveHints.ResolveCoalesceHints,
      ResolveHints.RemoveAllHints),
    Batch("Simple Sanity Check", Once,
      LookupFunctions),
    Batch("Substitution", fixedPoint,
      CTESubstitution,
      WindowsSubstitution,
      EliminateUnions,
      new SubstituteUnresolvedOrdinals(conf)),
    Batch("Resolution", fixedPoint,
      ResolveTableValuedFunctions ::                    //解析表的函数
      ResolveRelations ::                               //解析表或视图
      ResolveReferences ::                              //解析列
      ResolveCreateNamedStruct ::
      ResolveDeserializer ::                            //解析反序列化操作类
      ResolveNewInstance ::
      ResolveUpCast ::                                  //解析类型转换
      ResolveGroupingAnalytics ::
      ResolvePivot ::
      ResolveOrdinalInOrderByAndGroupBy ::
      ResolveAggAliasInGroupBy ::
      ResolveMissingReferences ::
      ExtractGenerator ::
      ResolveGenerate ::
      ResolveFunctions ::                               //解析函数
      ResolveAliases ::                                 //解析表别名
      ResolveSubquery ::                                //解析子查询
      ResolveSubqueryColumnAliases ::
      ResolveWindowOrder ::
      ResolveWindowFrame ::
      ResolveNaturalAndUsingJoin ::
      ResolveOutputRelation ::
      ExtractWindowExpressions ::
      GlobalAggregates ::
      ResolveAggregateFunctions ::
      TimeWindowing ::
      ResolveInlineTables(conf) ::
      ResolveHigherOrderFunctions(catalog) ::
      ResolveLambdaVariables(conf) ::
      ResolveTimeZone(conf) ::
      ResolveRandomSeed ::
      TypeCoercion.typeCoercionRules(conf) ++
      extendedResolutionRules : _*),
    Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
    Batch("View", Once,
      AliasViewChild(conf)),
    Batch("Nondeterministic", Once,
      PullOutNondeterministic),
    Batch("UDF", Once,
      HandleNullInputsForUDF),
    Batch("FixNullability", Once,
      FixNullability),
    Batch("Subquery", Once,
      UpdateOuterReferences),
    Batch("Cleanup", fixedPoint,
      CleanupAliases)
)

从上面代码可以看出,多个性质类似的 Rule 组成一个 Batch;而多个 Batch 构成一个 batches。这些 batches 会由 RuleExecutor 执行,先按一个一个 Batch 顺序执行,然后对 Batch 里面的每个 Rule 顺序执行。每个 Batch 会执行一次(Once)或多次(FixedPoint,由spark.sql.optimizer.maxIterations 参数决定),执行过程如下:

 

5.4 Optimizer优化逻辑计划

优化器也是会定义一套Rules,利用这些Rule对逻辑计划和Exepression进行迭代处理,从而使得树的节点进行合并和优化

 

 

在前文的绑定逻辑计划阶段对 Unresolved LogicalPlan 进行相关 transform 操作得到了 Analyzed Logical Plan,这个 Analyzed Logical Plan 是可以直接转换成 Physical Plan 然后在spark中执行。但是如果直接这么弄的话,得到的 Physical Plan 很可能不是最优的,因为在实际应用中,很多低效的写法会带来执行效率的问题,需要进一步对Analyzed Logical Plan 进行处理,得到更优的逻辑算子树。于是,针对SQL 逻辑算子树的优化器 Optimizer 应运而生。

这个阶段的优化器主要是基于规则的(Rule-based Optimizer,简称 RBO),而绝大部分的规则都是启发式规则,也就是基于直观或经验而得出的规则,比如列裁剪(过滤掉查询不需要使用到的列)、谓词下推(将过滤尽可能地下沉到数据源端)、常量累加(比如 1 + 2 这种事先计算好) 以及常量替换(比如 SELECT * FROM table WHERE i = 5 AND j = i + 3 可以转换成 SELECT * FROM table WHERE i = 5 AND j = 8)等等。

与绑定逻辑计划阶段类似,这个阶段所有的规则也是实现 Rule 抽象类,多个规则组成一个 Batch,多个 Batch 组成一个 batches,同样也是在 RuleExecutor 中进行执行。

核心源码骨架如下列截图所示:      

 

 

 

 

 

那么针对前文的 SQL 语句,这个过程都会执行哪些优化呢?下文举例说明。

5.4.1谓词下推

谓词下推在 Spark SQL 是由 PushDownPredicate 实现的,这个过程主要将过滤条件尽可能地下推到底层,最好是数据源。上面介绍的 SQL,使用谓词下推优化得到的逻辑计划如下:

 

从上图可以看出,谓词下推将 Filter 算子直接下推到 Join 之前了(注意,上图是从下往上看的)。也就是在扫描 t1 表的时候会先使用 ((((isnotnull(cid#2) && isnotnull(did#3)) && (cid#2 = 1)) && (did#3 = 2)) && (id#0 > 50000)) && isnotnull(id#0) 过滤条件过滤出满足条件的数据;同时在扫描 t2 表的时候会先使用 isnotnull(id#8) && (id#8 > 50000) 过滤条件过滤出满足条件的数据。经过这样的操作,可以大大减少 Join 算子处理的数据量,从而加快计算速度。

5.4.2列裁剪

列裁剪在 Spark SQL 是由 ColumnPruning 实现的。因为我们查询的表可能有很多个字段,但是每次查询我们很大可能不需要扫描出所有的字段,这个时候利用列裁剪可以把那些查询不需要的字段过滤掉,使得扫描的数据量减少。所以针对我们上面介绍的 SQL,使用列裁剪优化得到的逻辑计划如下:

 

从上图可以看出,经过列裁剪后,t1 表只需要查询 id 和 value 两个字段;t2 表只需要查询 id 字段。这样减少了数据的传输,而且如果底层的文件格式为列存(比如 Parquet),可以大大提高数据的扫描速度的。

 

5.4.3常量替换

常量替换在 Spark SQL 是由 ConstantPropagation 实现的。也就是将变量替换成常量,比如 SELECT * FROM table WHERE i = 5 AND j = i + 3 可以转换成 SELECT * FROM table WHERE i = 5 AND j = 8。这个看起来好像没什么的,但是如果扫描的行数非常多可以减少很多的计算时间的开销的。经过这个优化,得到的逻辑计划如下:

我们的查询中有 t1.cid = 1 AND t1.did = t1.cid + 1 查询语句,从里面可以看出 t1.cid 其实已经是确定的值了,所以我们完全可以使用它计算出 t1.did。

5.4.4常量累加

常量累加在 Spark SQL 是由 ConstantFolding 实现的。这个和常量替换类似,也是在这个阶段把一些常量表达式事先计算好。这个看起来改动的不大,但是在数据量非常大的时候可以减少大量的计算,减少 CPU 等资源的使用。经过这个优化,得到的逻辑计划如下:

 

经过上面四个步骤的优化之后,得到的优化之后的逻辑计划为:

Plain Text
== Optimized Logical Plan ==
Aggregate [sum(cast(v#16 as bigint)) AS sum(v)#22L]
+- Project [(3 + value#1) AS v#16]
   +- Join Inner, (id#0 = id#8)
      :- Project [id#0, value#1]
      :  +- Filter (((((isnotnull(cid#2) && isnotnull(did#3)) && (cid#2 = 1)) && (did#3 = 2)) && (id#0 > 5)) && isnotnull(id#0))
      :     +- Relation[id#0,value#1,cid#2,did#3] csv
      +- Project [id#8]
         +- Filter (isnotnull(id#8) && (id#8 > 5))
            +- Relation[id#8,value#9,cid#10,did#11] csv

对应的图如下:

 

到这里,优化逻辑计划阶段就算完成了。另外,Spark 内置提供了多达70个优化 Rule,详情请参见

https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L59

5.5使用SparkPlanner生成物理计划

SparkSpanner使用Planning Strategies,对优化后的逻辑计划进行转换,生成可以执行的物理计划SparkPlan.

Scala
/**
 * 将逻辑计划转成物理计划的抽象类.
 * 各实现类通过各种GenericStrategy来生成各种可行的待选物理计划.
 * 如一个策略无法对逻辑计划树的所有操作转换,则会调用[GenericStrategy#planLater planLater]], 来获得       一个“占位符”对象暂时填充;之后由[[collectPlaceholders collected]]收集并使用其他策略进行转换

 * TODO: 目前为止,永远只生成一个物理计划
 *       后续迭代中会对“多计划”予以实现
 */
abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
  /** A list of execution strategies that can be used by the planner */
  def strategies: Seq[GenericStrategy[PhysicalPlan]]

  def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = {
    // 显然,此处还有大量工作需要做,可依然...

    // 收集所有可选的物理计划.
    val candidates = strategies.iterator.flatMap(_(plan))

abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
  self: SparkPlanner =>

  /**
   * Plans special cases of limit operators.
   */
  object SpecialLimits extends Strategy {
      
   
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val experimentalMethods: ExperimentalMethods)
  extends SparkStrategies {      

逻辑计划翻译成物理计划时,使用的是策略(Strategy);

前面介绍的逻辑计划绑定和优化经过 Transformations 动作之后,树的类型并没有改变,

Logical Plan转化成物理计划后,树的类型改变了,由 Logical Plan 转换成 Physical Plan 了。

一个逻辑计划(Logical Plan)经过一系列的策略处理之后,得到多个物理计划(Physical Plans),物理计划在 Spark 是由 SparkPlan 实现的。

多个物理计划经过代价模型(Cost Model)得到选择后的物理计划(Selected Physical Plan),整个过程如下所示:

 

Cost Model 对应的就是基于代价的优化(Cost-based Optimizations,CBO,主要由华为的大佬们实现的,详见 SPARK-16026 ),核心思想是计算每个物理计划的代价,然后得到最优的物理计划。目前,这一部分并没有实现,直接返回多个物理计划列表的第一个作为最优的物理计划,如下:

Scala
lazy val sparkPlan: SparkPlan = {
    SparkSession.setActiveSession(sparkSession)
    // TODO: We use next(), i.e. take the first plan returned by the planner, here for now,
    //       but we will implement to choose the best plan.
    planner.plan(ReturnAnswer(optimizedPlan)).next()
}

而 SPARK-16026 引入的 CBO 优化主要是在前面介绍的优化逻辑计划阶段 - Optimizer 阶段进行的,对应的 Rule 为 CostBasedJoinReorder,并且默认是关闭的,需要通过 spark.sql.cbo.enabled 或 spark.sql.cbo.joinReorder.enabled 参数开启。

所以到了这个节点,最后得到的物理计划如下:

Plain Text
== Physical Plan ==
*(3) HashAggregate(keys=[], functions=[sum(cast(v#16 as bigint))], output=[sum(v)#22L])
+- Exchange SinglePartition
   +- *(2) HashAggregate(keys=[], functions=[partial_sum(cast(v#16 as bigint))], output=[sum#24L])
      +- *(2) Project [(3 + value#1) AS v#16]
         +- *(2) BroadcastHashJoin [id#0], [id#8], Inner, BuildRight
            :- *(2) Project [id#0, value#1]
            :  +- *(2) Filter (((((isnotnull(cid#2) && isnotnull(did#3)) && (cid#2 = 1)) && (did#3 = 2)) && (id#0 > 5)) && isnotnull(id#0))
            :     +- *(2) FileScan csv [id#0,value#1,cid#2,did#3] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/iteblog/t1.csv], PartitionFilters: [], PushedFilters: [IsNotNull(cid), IsNotNull(did), EqualTo(cid,1), EqualTo(did,2), GreaterThan(id,5), IsNotNull(id)], ReadSchema: struct<id:int,value:int,cid:int,did:int>
            +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)))
               +- *(1) Project [id#8]
                  +- *(1) Filter (isnotnull(id#8) && (id#8 > 5))
                     +- *(1) FileScan csv [id#8] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/iteblog/t2.csv], PartitionFilters: [], PushedFilters: [IsNotNull(id), GreaterThan(id,5)], ReadSchema: struct<id:int>

从上面的结果可以看出,物理计划阶段已经知道数据源是从 csv 文件里面读取了,也知道文件的路径,数据类型等。而且在读取文件的时候,直接将过滤条件(PushedFilters)加进去了。

同时,这个 Join 变成了 BroadcastHashJoin,也就是将 t2 表的数据 Broadcast 到 t1 表所在的节点。图表示如下:

 

到这里, Physical Plan 就完全生成了。

5.6从物理执行计划获取inputRdd执行

从物理计划上,获取inputRdd

从物理计划上,生成全阶段代码,并编译反射出迭代器newBiIterator的Clazz

 [真名:BufferedRowIterator]

然后将inputRDD做一个transformation得到最终要执行的rdd

Scala
inputRdd.mapPartitionsWithIndex((index,iter)=>{
   new newBiIterator(){
     hasNext(){
         iter.hasNext
}
     next(){
         processNext(iter.next())
}
}
})

然后,对最后返回的rdd,执行你所需要的行动算子
rdd.collect().foreach(println)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值