数据（user.json 的内容，每行一个独立的 JSON 对象）
{"name":"123","age":20}
{"name":"456","age":30}
{"name":"789","age":40}
代码（SparkSQL05_UDAF.scala，自定义 UDAF 求平均年龄）
package com.bfd.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DoubleType, LongType, StructType}
object SparkSQL05_UDAF {

  /**
   * Reads a line-delimited JSON file of users, registers a custom
   * aggregate function `avgAge`, and prints the average age via Spark SQL.
   *
   * @param args optional; args(0) overrides the input JSON path
   *             (defaults to the original hard-coded path for
   *             backward compatibility).
   */
  def main(args: Array[String]): Unit = {
    // Spark configuration: local mode using all available cores.
    val sparkConf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("sparkSql_01")

    // Entry point for the DataFrame / SQL API.
    val spark =
      SparkSession.builder().config(sparkConf).getOrCreate()

    // Implicit conversions for the DataFrame/Dataset API.
    import spark.implicits._

    // Instantiate the UDAF and register it under the SQL name "avgAge".
    val udaf = new MyageFunction
    spark.udf.register("avgAge", udaf)

    // Generalized: allow the input path to be passed as the first CLI
    // argument instead of only the hard-coded Windows-specific path.
    val inputPath = args.headOption.getOrElse(
      "D:\\development\\code\\2021.01\\scala1\\spark_core\\src\\main\\in\\user.json")

    val frame = spark.read.json(inputPath)
    frame.createOrReplaceTempView("user")
    spark.sql("select avgAge(age) from user").show()

    // Release resources; side-effecting 0-arity method keeps its parens.
    spark.stop()
  }
}
//声明用户自定义聚合函数
//1)继承UserDefinedAggregateFunction
//2)实现方法
class MyageFunction extends UserDefinedAggregateFunction {

  // Schema of the aggregate's input: a single LongType "age" column.
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Schema of the aggregation buffer: running sum and row count.
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // The final result (the average) is returned as a Double.
  override def dataType = DoubleType

  // Same input always yields the same output, so Spark may cache/reuse it.
  override def deterministic = true

  // Zero the buffer before aggregation starts.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // sum of ages
    buffer(1) = 0L // number of rows folded in
  }

  // Fold one input row into the buffer.
  // FIX: skip null ages — Row.getLong throws a NullPointerException on a
  // null cell, so a JSON record without an "age" field would crash the job.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  // Merge two partial buffers (e.g. from different partitions/executors).
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // accumulate sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // accumulate count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Produce the final average. Note: Double division by a zero count
  // yields NaN (original behavior, deliberately preserved).
  override def evaluate(buffer: Row) = {
    buffer.getLong(0).toDouble / buffer.getLong(1).toDouble
  }
}
结果：avgAge(age) = (20 + 30 + 40) / 3 = 30.0