How SparkSQL Executes DF.agg

This article walks through how SparkSQL computes statistics such as max, avg, and min via `DF.agg`: from `groupBy` to the `agg` method, through DataFrame construction and optimization, and how an `UnresolvedFunction` is resolved into an `AggregateExpression` during the Analyzer phase. Along the way it explains why passing a simple expression map is enough to get the expected result, and notes that this path is faster than computing with RDDs.
Before writing the previous article, I had never understood why the code below can produce a max, avg, or min value:

malePPL.agg(Map("height" -> "max", "sex" -> "count")).show


The data is of the shape (height, sex), with roughly a few million such rows.

I started out computing this with `reduceByKey`, then noticed posts online that get these values directly inside `agg`, so I dug in to see why passing a `Map(column -> expression name)` is enough to get the desired result.

Let's step straight into the agg method:

/**
 * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups.
 * {{{
 *   // df.agg(...) is a shorthand for df.groupBy().agg(...)
 *   df.agg(Map("age" -> "max", "salary" -> "avg"))
 *   df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
 * }}}
 * @group dfops
 * @since 1.3.0
 */
def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)


You can see it calls agg on the object returned by groupBy, and groupBy returns a GroupedData:

@scala.annotation.varargs
def groupBy(cols: Column*): GroupedData = {
  GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
}
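Before diving into GroupedData's internals, the shape of this handoff can be sketched with a toy model. Nothing here is Spark's API — `Frame` and `Grouped` are hypothetical stand-ins, and the evaluation is eager instead of building a logical plan — but it shows why a `Map(column -> function name)` is all the caller needs to supply:

```scala
// Toy sketch of the groupBy()/agg handoff; Frame and Grouped are hypothetical
// stand-ins, and agg evaluates eagerly rather than building a logical plan.
case class Frame(rows: Seq[Map[String, Double]]) {
  // groupBy captures the frame and returns a small builder
  def groupBy(): Grouped = Grouped(this)
  // agg on the frame is shorthand for groupBy().agg, as in the real API
  def agg(exprs: Map[String, String]): Map[String, Double] = groupBy().agg(exprs)
}

case class Grouped(df: Frame) {
  // interpret each (column -> function name) pair and evaluate it
  def agg(exprs: Map[String, String]): Map[String, Double] =
    exprs.map { case (col, fn) =>
      val vs = df.rows.map(_(col))
      val value = fn match {
        case "max"   => vs.max
        case "min"   => vs.min
        case "avg"   => vs.sum / vs.size
        case "count" => vs.size.toDouble
      }
      s"$fn($col)" -> value
    }
}
```

For example, `Frame(Seq(Map("height" -> 170.0), Map("height" -> 180.0))).agg(Map("height" -> "max"))` yields `Map("max(height)" -> 180.0)`.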


GroupedData's agg method:


def agg(exprs: Map[String, String]): DataFrame = {
  toDF(exprs.map { case (colName, expr) =>
    strToExpr(expr)(df(colName).expr)
  }.toSeq)
}


It builds a DataFrame with toDF; and inside strToExpr, the function name is actually wrapped into an UnresolvedFunction:

private[this] def strToExpr(expr: String): (Expression => Expression) = {
  val exprToFunc: (Expression => Expression) = {
    (inputExpr: Expression) => expr.toLowerCase match {
      // We special handle a few cases that have alias that are not in function registry.
      case "avg" | "average" | "mean" =>
        UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false)
      case "stddev" | "std" =>
        UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false)
      // Also special handle count because we need to take care count(*).
      case "count" | "size" =>
        // Turn count(*) into count(1)
        inputExpr match {
          case s: Star => Count(Literal(1)).toAggregateExpression()
          case _ => Count(inputExpr).toAggregateExpression()
        }
      case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false)
    }
  }
  (inputExpr: Expression) => exprToFunc(inputExpr)
}
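That dispatch can be mimicked in a small self-contained sketch (the `Expr`, `Column`, and `Func` types here are hypothetical stand-ins, not Spark's): the string is lowercased, a few aliases are collapsed to a canonical name, and everything else falls through to a generic unresolved function:

```scala
// Sketch of strToExpr's string-to-function dispatch; Expr/Column/Func are toy types.
sealed trait Expr
case class Column(name: String) extends Expr
case class Func(name: String, child: Expr) extends Expr

object StrToExpr {
  def strToExpr(alias: String): Expr => Expr = (input: Expr) =>
    alias.toLowerCase match {
      case "avg" | "average" | "mean" => Func("avg", input)    // aliases collapse to one name
      case "stddev" | "std"           => Func("stddev", input)
      case "count" | "size"           => Func("count", input)
      case name                       => Func(name, input)     // anything else passes through
    }
}
```

So `StrToExpr.strToExpr("mean")(Column("height"))` gives `Func("avg", Column("height"))` — the alias never survives past this point.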



Now let's look at how toDF is written:

private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
  val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
    groupingExprs ++ aggExprs
  } else {
    aggExprs
  }

  val aliasedAgg = aggregates.map(alias)

  groupType match {
    case GroupedData.GroupByType =>
      DataFrame(
        df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
    case GroupedData.RollupType =>
      DataFrame(
        df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
    case GroupedData.CubeType =>
      DataFrame(
        df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
    case GroupedData.PivotType(pivotCol, values) =>
      val aliasedGrps = groupingExprs.map(alias)
      DataFrame(
        df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
  }
}

In the groupBy method we saw that the groupType passed in is GroupedData.GroupByType, so this is the branch that executes:

case GroupedData.GroupByType =>
  DataFrame(
    df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))

Aggregate extends UnaryNode, which makes it a LogicalPlan:

case class Aggregate(
    groupingExpressions: Seq[Expression],
    aggregateExpressions: Seq[NamedExpression],
    child: LogicalPlan)
  extends UnaryNode {

  override lazy val resolved: Boolean = {
    val hasWindowExpressions = aggregateExpressions.exists ( _.collect {
        case window: WindowExpression => window
      }.nonEmpty
    )

    !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
  }

  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}
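The `resolved` check above can be condensed into a toy model (all types here are illustrative, not Spark's): a node counts as resolved only when every one of its expressions is resolved, its children are resolved, and no window expression is left in the aggregate list:

```scala
// Toy model of Aggregate.resolved; every type here is a stand-in.
sealed trait Expr { def resolved: Boolean; def isWindow: Boolean = false }
case class ResolvedCol(name: String) extends Expr { val resolved = true }
case class UnresolvedFn(name: String) extends Expr { val resolved = false }
case class WindowExpr(child: Expr) extends Expr {
  val resolved = child.resolved
  override val isWindow = true
}

case class AggNode(exprs: Seq[Expr], childrenResolved: Boolean) {
  // mirrors: !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
  lazy val resolved: Boolean =
    !exprs.exists(!_.resolved) && childrenResolved && !exprs.exists(_.isWindow)
}
```

A single `UnresolvedFn` (which is exactly what our max/avg strings are at this point) keeps the whole node unresolved, which is why the Analyzer must run before anything can execute.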


This logical plan contains the expressions we passed in, e.g. height -> max. After these steps a DataFrame has been created. As described in the earlier article, the DataFrame goes through the following steps to refine the logical plan into an executable physical plan (which includes resolving the UnresolvedFunction):
1. SqlParser produces an unresolved LogicalPlan
2. The Analyzer turns it into a resolved LogicalPlan
3. The Optimizer turns it into an optimized LogicalPlan
4. The SparkPlanner turns it into a physical plan
5. prepareForExecution turns it into an executable plan
6. toRdd runs the executed plan, calling doExecute down the tree
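The phases above can be sketched as a chain of plan-to-plan functions (the `Plan` type and the phase bodies below are toy stand-ins for Catalyst's, purely to show the shape of the pipeline):

```scala
// Toy sketch of the planning pipeline as function composition; Plan is not Spark's type.
object Pipeline {
  case class Plan(trace: List[String])

  val analyze: Plan => Plan    = p => Plan(p.trace :+ "analyzed")
  val optimize: Plan => Plan   = p => Plan(p.trace :+ "optimized")
  val toPhysical: Plan => Plan = p => Plan(p.trace :+ "physical")
  val prepare: Plan => Plan    = p => Plan(p.trace :+ "prepared")

  // each phase consumes the previous phase's output, like QueryExecution's chained lazy vals
  val run: Plan => Plan = analyze andThen optimize andThen toPhysical andThen prepare
}
```

Running `Pipeline.run(Pipeline.Plan(List("unresolved")))` traces out exactly the order listed above.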

With that in mind, let's see how an UnresolvedFunction gets linked up with the max, min, and avg expressions. Going into the Analyzer: SQLContext passes in a registry when it creates the Analyzer:

protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()

protected[sql] lazy val analyzer: Analyzer =
  new Analyzer(catalog, functionRegistry, conf) {
    override val extendedResolutionRules =
      ExtractPythonUDFs ::
      PreInsertCastAndRename ::
      (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil)

    override val extendedCheckRules = Seq(
      datasources.PreWriteCheck(catalog)
    )
  }


This FunctionRegistry holds all the built-in expressions:

object FunctionRegistry {

  type FunctionBuilder = Seq[Expression] => Expression

  val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
    // misc non-aggregate functions
    expression[Abs]("abs"),
    expression[CreateArray]("array"),
    expression[Coalesce]("coalesce"),
    expression[Explode]("explode"),
    expression[Greatest]("greatest"),
    expression[If]("if"),
    expression[IsNaN]("isnan"),
    expression[IsNull]("isnull"),
    expression[IsNotNull]("isnotnull"),
    expression[Least]("least"),
    expression[Coalesce]("nvl"),
    expression[Rand]("rand"),
    expression[Randn]("randn"),
    expression[CreateStruct]("struct"),
    expression[CreateNamedStruct]("named_struct"),
    expression[Sqrt]("sqrt"),
    expression[NaNvl]("nanvl"),

    // math functions
    expression[Acos]("acos"),
    expression[Asin]("asin"),
    expression[Atan]("atan"),
    expression[Atan2]("atan2"),
    expression[Bin]("bin"),
    expression[Cbrt]("cbrt"),
    expression[Ceil]("ceil"),
    expression[Ceil]("ceiling"),
    expression[Cos]("cos"),
    expression[Cosh]("cosh"),
    expression[Conv]("conv"),
    expression[EulerNumber]("e"),
    expression[Exp]("exp"),
    expression[Expm1]("expm1"),
    expression[Floor]("floor"),
    expression[Factorial]("factorial"),
    expression[Hypot]("hypot"),
    expression[Hex]("hex"),
    expression[Logarithm]("log"),
    expression[Log]("ln"),
    expression[Log10]("log10"),
    expression[Log1p]("log1p"),
    expression[Log2]("log2"),
    expression[UnaryMinus]("negative"),
    expression[Pi]("pi"),
    expression[Pow]("pow"),
    expression[Pow]("power"),
    expression[Pmod]("pmod"),
    expression[UnaryPositive]("positive"),
    expression[Rint]("rint"),
    expression[Round]("round"),
    expression[ShiftLeft]("shiftleft"),
    expression[ShiftRight]("shiftright"),
    expression[ShiftRightUnsigned]("shiftrightunsigned"),
    expression[Signum]("sign"),
    expression[Signum]("signum"),
    expression[Sin]("sin"),
    expression[Sinh]("sinh"),
    expression[Tan]("tan"),
    expression[Tanh]("tanh"),
    expression[ToDegrees]("degrees"),
    expression[ToRadians]("radians"),

    // aggregate functions
    expression[HyperLogLogPlusPlus]("approx_count_distinct"),
    expression[Average]("avg"),
    expression[Corr]("corr"),
    expression[Count]("count"),
    expression[First]("first"),
    expression[First]("first_value"),
    expression[Last]("last"),
    expression[Last]("last_value"),
    expression[Max]("max"),
    expression[Average]("mean"),
    expression[Min]("min"),
    expression[StddevSamp]("stddev"),
    expression[StddevPop]("stddev_pop"),
    expression[StddevSamp]("stddev_samp"),
    expression[Sum]("sum"),
    expression[VarianceSamp]("variance"),
    expression[VariancePop]("var_pop"),
    expression[VarianceSamp]("var_samp"),
    expression[Skewness]("skewness"),
    expression[Kurtosis]("kurtosis"),

    // string functions
    expression[Ascii]("ascii"),
    expression[Base64]("base64"),
    expression[Concat]("concat"),
    expression[ConcatWs]("concat_ws"),
    expression[Encode]("encode"),
    expression[Decode]("decode"),
    expression[FindInSet]("find_in_set"),
    expression[FormatNumber]("format_number"),
    expression[GetJsonObject]("get_json_object"),
    expression[InitCap]("initcap"),
    expression[JsonTuple]("json_tuple"),
    expression[Lower]("lcase"),
    expression[Lower]("lower"),
    expression[Length]("length"),
    expression[Levenshtein]("levenshtein"),
    expression[RegExpExtract]("regexp_extract"),
    expression[RegExpReplace]("regexp_replace"),
    expression[StringInstr]("instr"),
    expression[StringLocate]("locate"),
    expression[StringLPad]("lpad"),
    expression[StringTrimLeft]("ltrim"),
    expression[FormatString]("format_string"),
    expression[FormatString]("printf"),
    expression[StringRPad]("rpad"),
    expression[StringRepeat]("repeat"),
    expression[StringReverse]("reverse"),
    expression[StringTrimRight]("rtrim"),
    expression[SoundEx]("soundex"),
    expression[StringSpace]("space"),
    expression[StringSplit]("split"),
    expression[Substring]("substr"),
    expression[Substring]("substring"),
    expression[SubstringIndex]("substring_index"),
    expression[StringTranslate]("translate"),
    expression[StringTrim]("trim"),
    expression[UnBase64]("unbase64"),
    expression[Upper]("ucase"),
    expression[Unhex]("unhex"),
    expression[Upper]("upper"),

    // datetime functions
    expression[AddMonths]("add_months"),
    expression[CurrentDate]("current_date"),
    expression[CurrentTimestamp]("current_timestamp"),
    expression[CurrentTimestamp]("now"),
    expression[DateDiff]("datediff"),
    expression[DateAdd]("date_add"),
    expression[DateFormatClass]("date_format"),
    expression[DateSub]("date_sub"),
    expression[DayOfMonth]("day"),
    expression[DayOfYear]("dayofyear"),
    expression[DayOfMonth]("dayofmonth"),
    expression[FromUnixTime]("from_unixtime"),
    expression[FromUTCTimestamp]("from_utc_timestamp"),
    expression[Hour]("hour"),
    expression[LastDay]("last_day"),
    expression[Minute]("minute"),
    expression[Month]("month"),
    expression[MonthsBetween]("months_between"),
    expression[NextDay]("next_day"),
    expression[Quarter]("quarter"),
    expression[Second]("second"),
    expression[ToDate]("to_date"),
    expression[ToUnixTimestamp]("to_unix_timestamp"),
    expression[ToUTCTimestamp]("to_utc_timestamp"),
    expression[TruncDate]("trunc"),
    expression[UnixTimestamp]("unix_timestamp"),
    expression[WeekOfYear]("weekofyear"),
    expression[Year]("year"),

    // collection functions
    expression[Size]("size"),
    expression[SortArray]("sort_array"),
    expression[ArrayContains]("array_contains"),

    // misc functions
    expression[Crc32]("crc32"),
    expression[Md5]("md5"),
    expression[Sha1]("sha"),
    expression[Sha1]("sha1"),
    expression[Sha2]("sha2"),
    expression[SparkPartitionID]("spark_partition_id"),
    expression[InputFileName]("input_file_name"),
    expression[MonotonicallyIncreasingID]("monotonically_increasing_id")
  )
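The registry pattern itself is simple: a map from a (lowercased) function name to a builder that turns the child expressions into a concrete expression. A minimal stand-alone sketch with hypothetical types:

```scala
// Minimal function-registry sketch mirroring lookupFunction; all types are toy stand-ins.
object Registry {
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class MaxExpr(child: Expr) extends Expr
  case class AvgExpr(child: Expr) extends Expr

  type FunctionBuilder = Seq[Expr] => Expr

  private val builders: Map[String, FunctionBuilder] = Map(
    "max"  -> ((args: Seq[Expr]) => MaxExpr(args.head)),
    "avg"  -> ((args: Seq[Expr]) => AvgExpr(args.head)),
    "mean" -> ((args: Seq[Expr]) => AvgExpr(args.head))  // same builder under an alias
  )

  def lookupFunction(name: String, children: Seq[Expr]): Expr =
    builders.get(name.toLowerCase) match {
      case Some(builder) => builder(children)
      case None          => sys.error(s"undefined function $name")
    }
}
```

Registering the same builder under several names is how aliases like "avg"/"mean" end up producing the same expression.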



So when the Analyzer's execute method applies its rules to every node, one of those rules is ResolveFunctions. Here are the batches the Analyzer defines:

lazy val batches: Seq[Batch] = Seq(
  Batch("Substitution", fixedPoint,
    CTESubstitution,
    WindowsSubstitution),
  Batch("Resolution", fixedPoint,
    ResolveRelations ::
    ResolveReferences ::
    ResolveGroupingAnalytics ::
    ResolvePivot ::
    ResolveUpCast ::
    ResolveSortReferences ::
    ResolveGenerate ::
    ResolveFunctions ::
    ResolveAliases ::
    ExtractWindowExpressions ::
    GlobalAggregates ::
    ResolveAggregateFunctions ::
    HiveTypeCoercion.typeCoercionRules ++
    extendedResolutionRules : _*),
  Batch("Nondeterministic", Once,
    PullOutNondeterministic,
    ComputeCurrentTime),
  Batch("UDF", Once,
    HandleNullInputsForUDF),
  Batch("Cleanup", fixedPoint,
    CleanupAliases)
)


ResolveFunctions is defined like this:

object ResolveFunctions extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case q: LogicalPlan =>
      q transformExpressions {
        case u if !u.childrenResolved => u // Skip until children are resolved.
        case u @ UnresolvedFunction(name, children, isDistinct) =>
          withPosition(u) {
            registry.lookupFunction(name, children) match {
              // DISTINCT is not meaningful for a Max or a Min.
              case max: Max if isDistinct =>
                AggregateExpression(max, Complete, isDistinct = false)
              case min: Min if isDistinct =>
                AggregateExpression(min, Complete, isDistinct = false)
              // We get an aggregate function, we need to wrap it in an AggregateExpression.
              case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
              // This function is not an aggregate function, just return the resolved one.
              case other => other
            }
          }
      }
  }
}


This rule walks over all the expressions and looks each function up in the registry:

registry.lookupFunction(name, children) match {
  ...
}

If what we passed in is max or min, or any other name that resolves to an aggregate function, an AggregateExpression comes straight back:

AggregateExpression(max, Complete, isDistinct = false)
AggregateExpression(min, Complete, isDistinct = false)
AggregateExpression(agg, Complete, isDistinct)

So the max/min string we passed in has been replaced by the expression registered in the FunctionRegistry, and the remaining rules then turn the plan into one with resolved aggregate functions.
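The transformExpressions pass can be mimicked by a recursive rewrite that swaps every unresolved node for whatever the registry builds. All types below are illustrative, and `Agg` stands in for AggregateExpression:

```scala
// Sketch of the ResolveFunctions rewrite over a toy expression tree.
object ResolveSketch {
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class Unresolved(fn: String, child: Expr) extends Expr
  case class Agg(fn: String, child: Expr) extends Expr  // stands in for AggregateExpression

  val registry: Set[String] = Set("max", "min", "avg", "count")

  def resolve(e: Expr): Expr = e match {
    // known function name: wrap the (recursively resolved) child in an aggregate node
    case Unresolved(fn, child) if registry(fn) => Agg(fn, resolve(child))
    // unknown name: fail, like the real analysis check does
    case Unresolved(fn, _)                     => sys.error(s"undefined function $fn")
    case other                                 => other
  }
}
```

After this pass no Unresolved node remains, so the plan's resolved check finally passes.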

In other words, the max, min, or avg we specify is already baked into the final execution plan at the moment the DataFrame is built, so it is no longer surprising that passing these arguments yields those results.

In my tests, computing these statistics by passing expressions to agg was far faster than computing them with RDD operations. For now, if agg can produce the result you want, there is no reason to drop down to RDD computation.

If anything here is wrong, corrections are welcome.

PS: you can try passing in a function name that is not in the FunctionRegistry — the checkAnalysis(resolvedAggregate) step will detect it and throw an exception.