Spark 自定义聚合函数(UDAF)UserDefinedAggregateFunction 原理用法示例源码分析
原理
UserDefinedAggregateFunction
是 Spark SQL 中用于实现用户自定义聚合函数(UDAF)的抽象类。通过继承该类并实现其中的方法,可以创建自定义的聚合函数,并在 Spark SQL 中使用。
UserDefinedAggregateFunction
的原理是基于 Spark SQL 的聚合操作流程。当一个 UDAF 被应用到 DataFrame 上时,Spark SQL 会将 UDAF 转化为一个 AggregateExpression
对象,其中包含了对应的 ScalaUDAF
实例和聚合操作类型。然后,Spark SQL 会对数据进行分组、聚合等操作,并调用 UDAF 中的方法来执行具体的聚合逻辑。
在具体实现中,UserDefinedAggregateFunction
提供了一系列方法,如 inputSchema
、bufferSchema
、dataType
等,用于定义输入参数的数据类型、缓冲区中值的数据类型以及返回值的数据类型。同时,它还提供了 initialize
、update
、merge
和 evaluate
方法,用于初始化聚合缓冲区、更新缓冲区、合并缓冲区以及计算最终结果。此外,UserDefinedAggregateFunction
还提供了 apply
和 distinct
方法,用于创建 Column
对象,方便在 DataFrame 中使用自定义聚合函数。
总的来说,UserDefinedAggregateFunction
通过定义一系列方法,使得用户可以灵活地实现自定义的聚合逻辑,并将其应用到 Spark SQL 的聚合操作中。通过这种方式,用户可以扩展 Spark SQL 中的聚合能力,满足特定的业务需求。
用法
方法名 | 描述 |
---|---|
inputSchema | 返回聚合函数的输入参数的数据类型的 StructType 。 |
bufferSchema | 返回聚合缓冲区中值的数据类型的 StructType 。 |
dataType | 返回聚合函数的返回值的数据类型。 |
deterministic | 返回布尔值,指示此函数是否是确定性的。 |
initialize(buffer) | 初始化给定的聚合缓冲区。 |
update(buffer, input) | 使用新的输入数据更新聚合缓冲区。 |
merge(buffer1, buffer2) | 合并两个聚合缓冲区。 |
evaluate(buffer) | 根据给定的聚合缓冲区计算最终结果。 |
apply(exprs) | 使用给定的 Column 参数创建一个 Column 对象来调用 UDAF。 |
distinct(exprs) | 使用给定的不同值的 Column 参数创建一个 Column 对象来调用 UDAF。 |
update(i, value) | 更新可变聚合缓冲区的第 i 个值。 |
示例
package org.example.spark
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
object AverageVecDemo {
// 创建自定义聚合函数
class MyAverage extends UserDefinedAggregateFunction {
// 输入参数的数据类型
def inputSchema: StructType = new StructType().add("value", DoubleType)
// 聚合缓冲区中值的数据类型
def bufferSchema: StructType = new StructType()
.add("sum", DoubleType)
.add("count", LongType)
// 返回值的数据类型
def dataType: DataType = DoubleType
// 是否是确定性的
def deterministic: Boolean = true
// 初始化聚合缓冲区
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0.0 // sum
buffer(1) = 0L // count
}
// 更新聚合缓冲区
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0)) {
val value = input.getDouble(0)
buffer(0) = buffer.getDouble(0) + value
buffer(1) = buffer.getLong(1) + 1
}
}
// 合并两个聚合缓冲区
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// 计算最终结果
def evaluate(buffer: Row): Any = {
buffer.getDouble(0) / buffer.getLong(1)
}
}
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("UDAFDemo")
.master("local[*]")
.getOrCreate()
import spark.implicits._
// 创建一个 DataFrame
val data = Seq(1.0, 2.0, 3.0, 4.0, 5.0).toDF("value")
// 注册自定义聚合函数
spark.udf.register("myAverage", new MyAverage)
// 使用自定义聚合函数进行聚合操作
val result = data.selectExpr("myAverage(value) as average")
result.show()
spark.stop()
}
}
//+-------+
//|average|
//+-------+
//| 3.0|
//+-------+
这个示例中,我们创建了一个自定义聚合函数 MyAverage
,用于计算输入数据列的平均值。然后,我们将该函数注册到 Spark 的 UDF(用户定义函数)中,并在 DataFrame 中使用 selectExpr
方法调用它进行聚合操作。最后,我们展示了聚合结果。
源码
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.types._
/**
* 实现用户自定义聚合函数(UDAF)的基类。
*
* @since 1.5.0
*/
@InterfaceStability.Stable
abstract class UserDefinedAggregateFunction extends Serializable {
/**
* `StructType` 表示此聚合函数的输入参数的数据类型。
* 例如,如果一个[[UserDefinedAggregateFunction]]期望两个输入参数,
* 分别是`DoubleType`和`LongType`类型,返回的`StructType`将如下所示:
*
* ```
* new StructType()
* .add("doubleInput", DoubleType)
* .add("longInput", LongType)
* ```
*
* 此`StructType`的字段名称仅用于标识对应的输入参数。用户可以选择名称以标识输入参数。
*
* @since 1.5.0
*/
def inputSchema: StructType
/**
* `StructType` 表示聚合缓冲区中值的数据类型。
* 例如,如果一个[[UserDefinedAggregateFunction]]的缓冲区有两个值
* (即两个中间值),分别是`DoubleType`和`LongType`类型,
* 返回的`StructType`将如下所示:
*
* ```
* new StructType()
* .add("doubleInput", DoubleType)
* .add("longInput", LongType)
* ```
*
* 此`StructType`的字段名称仅用于标识对应的缓冲区值。用户可以选择名称以标识输入参数。
*
* @since 1.5.0
*/
def bufferSchema: StructType
/**
* [[UserDefinedAggregateFunction]] 返回值的 `DataType`。
*
* @since 1.5.0
*/
def dataType: DataType
/**
* 如果此函数是确定性的,则返回true,即给定相同的输入,总是返回相同的输出。
*
* @since 1.5.0
*/
def deterministic: Boolean
/**
* 初始化给定的聚合缓冲区,即聚合缓冲区的初始值。
*
* 即应用于两个初始缓冲区的合并函数只应返回初始缓冲区本身,即
* `merge(initialBuffer, initialBuffer)` 应等于 `initialBuffer`。
*
* @since 1.5.0
*/
def initialize(buffer: MutableAggregationBuffer): Unit
/**
* 使用来自`input`的新输入数据更新给定的聚合缓冲区`buffer`。
*
* 每行输入调用一次此方法。
*
* @since 1.5.0
*/
def update(buffer: MutableAggregationBuffer, input: Row): Unit
/**
* 合并两个聚合缓冲区,并将更新后的缓冲区值存储回`buffer1`。
*
* 当我们合并两个部分聚合的数据时,会调用此方法。
*
* @since 1.5.0
*/
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
/**
* 根据给定的聚合缓冲区计算此[[UserDefinedAggregateFunction]]的最终结果。
*
* @since 1.5.0
*/
def evaluate(buffer: Row): Any
/**
* 使用给定的`Column`s作为输入参数创建此UDAF的`Column`。
*
* @since 1.5.0
*/
@scala.annotation.varargs
def apply(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
ScalaUDAF(exprs.map(_.expr), this),
Complete,
isDistinct = false)
Column(aggregateExpression)
}
/**
* 使用给定的`Column`s的不同值作为输入参数创建此UDAF的`Column`。
*
* @since 1.5.0
*/
@scala.annotation.varargs
def distinct(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
ScalaUDAF(exprs.map(_.expr), this),
Complete,
isDistinct = true)
Column(aggregateExpression)
}
}
/**
* 表示可变聚合缓冲区的`Row`。
*
* 不建议在Spark之外扩展它。
*
* @since 1.5.0
*/
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {
/** 更新此缓冲区的第i个值。 */
def update(i: Int, value: Any): Unit
}
gregateExpression)
}
}
/**
* 表示可变聚合缓冲区的`Row`。
*
* 不建议在Spark之外扩展它。
*
* @since 1.5.0
*/
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {
/** 更新此缓冲区的第i个值。 */
def update(i: Int, value: Any): Unit
}
参考链接
https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html