SparkSQL-自定义聚合函数 (实现几何平均数)
->创建SparkSessionparkSession
->创建自定义函数
-1、继承UserDefinedAggregateFunction
-2、重写下面的方法
inputSchema -输入数据的类型
bufferSchema -产生中间结果的数据类型
dataType -最终返回的结果类型
deterministic -确保一致性
initialize -指定初始值
update -每有一条数据参与运算就更新一下中间结果
merge -全局聚合
evaluate -计算最终结果
!:StructField -哪些列,啥类型
->实例化自定义函数,并注册自定义函数(spark.udf.register)
代码:
object Geometric {
def main(args: Array[String]): Unit = {
//创建sparkSession
val sparkSession: SparkSession = SparkSession.builder().appName("Geometric").master("local[*]").getOrCreate()
//造数据 1~10
val range: Dataset[lang.Long] = sparkSession.range(1, 11)
//实例化Geom类
val geomean = new Geom
//注册视图
range.createTempView("v_range")
//注册自定义函数
sparkSession.udf.register("ge", geomean)
//执行sparkSql语句
val res: DataFrame = sparkSession.sql("select ge(id) result from v_range")
res.show()
sparkSession.stop()
}
}
class Geom extends UserDefinedAggregateFunction {
//输入类型
override def inputSchema: StructType = StructType(List(StructField("value", DoubleType)))
//中间数据
override def bufferSchema: StructType = StructType(List(
StructField("product", DoubleType),
StructField("counts", LongType)
))
//最终返回结果类型
override def dataType: DataType = DoubleType
//确保一致性
override def deterministic: Boolean = true
//指定初始值
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 1.0
buffer(1) = 0L
}
//每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getDouble(0) * input.getDouble(0)
buffer(1) = buffer.getLong(1) + 1L
}
//全局聚合
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
override def evaluate(buffer: Row): Double = {
math.pow(buffer.getDouble(0), 1.toDouble / buffer.getLong(1))
}
}