一、弱类型(低版本3.0以下)
定义类继承 UserDefinedAggregateFunction,并重写其中方法
class MyAveragUDAF extends UserDefinedAggregateFunction {
// 聚合函数输入参数的数据类型
def inputSchema: StructType = StructType(Array(StructField("age",IntegerType)))
// 聚合函数缓冲区中值的数据类型(age,count) def bufferSchema: StructType = {
StructType(Array(StructField("sum",LongType),StructField("count",LongType)))
}
// 函数返回值的数据类型
def dataType: DataType = DoubleType
// 稳定性:对于相同的输入是否一直返回相同的输出。
def deterministic: Boolean = true
// 函数缓冲区初始化
def initialize(buffer: MutableAggregationBuffer): Unit = {
// 存年龄的总和
buffer(0) = 0L
// 存年龄的个数
buffer(1) = 0L
}
// 更新缓冲区中的数据
def update(buffer: MutableAggregationBuffer,input: Row): Unit = { if (!input.isNullAt(0)) {
buffer(0) = buffer.getLong(0) + input.getInt(0) buffer(1) = buffer.getLong(1) + 1
}
}
// 合并缓冲区
def merge(buffer1: MutableAggregationBuffer,buffer2: Row): Unit = { buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// 计算最终结果
def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}
。。。
//创建聚合函数
var myAverage = new MyAveragUDAF
//在 spark 中注册聚合函数
spark.udf.register("avgAge",myAverage)
二、强类型(3.0后推荐使用)
- 定义类继承 org.apache.spark.sql.expressions.Aggregator
- 重写类中的方法
低版本:采用DSL风格查询
//输入数据类型
case class User01(username:String,age:Long)
//缓存类型
case class AgeBuffer(var sum:Long,var count:Long)
/**
* 定义类继承 org.apache.spark.sql.expressions.Aggregator
* 重写类中的方法
*/
class MyAveragUDAF1 extends Aggregator[User01,AgeBuffer,Double]{ override def zero: AgeBuffer = {
AgeBuffer(0L,0L)
}
override def reduce(b: AgeBuffer, a: User01): AgeBuffer = {
b.sum = b.sum + a.age b.count = b.count + 1 b
}
override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = { b1.sum = b1.sum + b2.sum
b1.count = b1.count + b2.count b1
}
override def finish(buff: AgeBuffer): Double = { buff.sum.toDouble/buff.count
}
//DataSet 默认额编解码器,用于序列化,固定写法
//自定义类型就是 product 自带类型根据类型选择
override def bufferEncoder: Encoder[AgeBuffer] = { Encoders.product
}
override def outputEncoder: Encoder[Double] = { Encoders.scalaDouble
}
}
。。。
//封装为 DataSet
val ds: Dataset[User01] = df.as[User01]
//创建聚合函数
var myAgeUdaf1 = new MyAveragUDAF1
//将聚合函数转换为查询的列
val col: TypedColumn[User01, Double] = myAgeUdaf1.toColumn
ds.select(col).show()
高版本(3.0):注册函数方式
// TODO 创建 UDAF 函数
val udaf = new MyAvgAgeUDAF
// TODO 注册到 SparkSQL 中
spark.udf.register("avgAge", functions.udaf(udaf))
// TODO 在 SQL 中使用聚合函数
// 定义用户的自定义聚合函数
spark.sql("select avgAge(age) from user").show
// ************************************************** case class Buff( var sum:Long, var cnt:Long )
// totalage, count
class MyAvgAgeUDAF extends Aggregator[Long, Buff, Double]{ override def zero: Buff = Buff(0,0)
override def reduce(b: Buff, a: Long): Buff = { b.sum += a
b.cnt += 1 b
}
override def merge(b1: Buff, b2: Buff): Buff = { b1.sum += b2.sum
b1.cnt += b2.cnt b1
}
override def finish(reduction: Buff): Double = { reduction.sum.toDouble/reduction.cnt
}
override def bufferEncoder: Encoder[Buff] = Encoders.product override def outputEncoder: Encoder[Double] = Encoders.scalaDouble