像Hive一样自定义聚合函数
弱类型自定义聚合函数
继承UserDefinedAggregateFunction 来实现,面向DataFrame
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
/**
 * Untyped (DataFrame-oriented) user-defined aggregate function that computes
 * the arithmetic mean of a Long column, in the style of a Hive UDAF.
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0;
 * new code should prefer Aggregator + functions.udaf — confirm target Spark version.
 */
object MyAverage extends UserDefinedAggregateFunction {
  // Schema of the input argument(s) this aggregate accepts: one Long column.
  override def inputSchema: StructType =
    new StructType().add("inputColumn", LongType)

  // Schema of the intermediate aggregation buffer: running sum and row count.
  override def bufferSchema: StructType =
    new StructType().add("sum", LongType).add("count", LongType)

  // Result type produced by evaluate().
  override def dataType: DataType = DoubleType

  // Same input always yields the same output.
  override def deterministic: Boolean = true

  // Reset the buffer to the identity state (sum = 0, count = 0).
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }

  // Fold one input row into the buffer; null inputs are skipped.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.getLong(0) + input.getLong(0))
      buffer.update(1, buffer.getLong(1) + 1L)
    }
  }

  // Combine two partial buffers (e.g. from different partitions) into buffer1.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  // Final result: mean = sum / count.
  override def evaluate(buffer: Row): Double =
    buffer.getLong(0).toDouble / buffer.getLong(1).toDouble
}
// Register the UDAF under the name "myAverage" so it can be called from SQL.
// NOTE(review): assumes a SparkSession named `spark` is already in scope.
spark.udf.register("myAverage", MyAverage)
val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
// DataFrame-oriented usage: call the registered function via SQL.
val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
强类型自定义聚合函数
继承Aggregator来实现,面向Datasets
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
// Input row type: one employee record as read from the JSON file.
case class Employee(name: String, salary: Long)
// Intermediate aggregation buffer type: running sum and row count.
// Fields are `var` so reduce/merge can update the buffer in place.
case class Average(var sum: Long, var count: Long)
/**
 * Strongly-typed (Dataset-oriented) aggregator computing the mean salary of
 * [[Employee]] rows, implemented by extending Aggregator[IN, BUF, OUT].
 *
 * NOTE(review): this shares the name MyAverage with the untyped example above;
 * if both live in the same scope, one must be renamed.
 */
object MyAverage extends Aggregator[Employee, Average, Double] {
  // Identity element of the aggregation: any b merged with zero must equal b.
  override def zero: Average = Average(0L, 0L)

  // Fold one input row into the buffer. The buffer is mutated in place and
  // returned directly instead of allocating a new object, for performance.
  override def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum = buffer.sum + employee.salary
    buffer.count = buffer.count + 1L
    buffer
  }

  // Combine two partial buffers (i.e. merge partition results) into b1.
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // Final result: mean = sum / count.
  override def finish(reduction: Average): Double =
    reduction.sum.toDouble / reduction.count.toDouble

  // Encoder for the buffer type, used by Spark for internal (de)serialization.
  override def bufferEncoder: Encoder[Average] = Encoders.product[Average]

  // Encoder for the output type.
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
// NOTE(review): assumes a SparkSession named `spark` (with its implicits
// imported, for `.as[Employee]`) is already in scope.
val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
ds.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
// Convert the aggregator to a Column and name the result column.
val averageSalary = MyAverage.toColumn.name("average_salary")
// Dataset DSL-style usage.
val result = ds.select(averageSalary)
result.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+