// Custom aggregate functions come in two flavors: a strongly-typed one that requires the DSL API, and the untyped kind shown below.
import java.lang
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util.Properties
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.{JdbcRDD, RDD}
import org.apache.spark.util.{AccumulatorV2, LongAccumulator}
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
// Demo entry point: registers the untyped UDAF and invokes it through Spark SQL.
object test11 {
  def main(args: Array[String]): Unit = {
    // Build a local SparkSession.
    val spark: SparkSession =
      SparkSession.builder().master("local[*]").appName("haha").getOrCreate()
    import spark.implicits._
    val ssc: SparkContext = spark.sparkContext
    // Load the users as a DataFrame and expose them as a temp view for SQL.
    val userDF: DataFrame = spark.read.json("input/user.json")
    userDF.createOrReplaceTempView("user")
    // Instantiate and register the custom aggregate under the name "ageAvg".
    val avgFunction = new MyAvg
    spark.udf.register("ageAvg", avgFunction)
    // Run the aggregate via SQL and print the result.
    spark.sql("select ageAvg(age) from user").show()
    spark.stop()
  }
}
// Untyped user-defined aggregate function computing the average of a Long column.
// NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0 in
// favor of org.apache.spark.sql.expressions.Aggregator — consider migrating.
class MyAvg extends UserDefinedAggregateFunction {
  // Input schema: a single Long column (the values being averaged).
  override def inputSchema: StructType =
    StructType(Array(StructField("age", LongType)))
  // Buffer schema: running sum ("total") and contributing-row count ("count").
  override def bufferSchema: StructType =
    StructType(Array(StructField("total", LongType), StructField("count", LongType)))
  // Result type: the average as a Double (null when no rows contributed).
  override def dataType: DataType = DoubleType
  // The same input always produces the same output.
  override def deterministic: Boolean = true
  // Zero out both the running sum and the count.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  // Fold one input row into the buffer. Null ages are skipped so they
  // neither throw on getLong nor skew the count (standard AVG semantics).
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }
  // Merge two partial buffers: sums and counts add component-wise.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Final result: sum / count. Guard the empty/all-null group so we return
  // null (SQL AVG semantics) instead of NaN from a zero-count division.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}