需求:计算 1 到 10 这十个整数的几何平均数。
实现方式:自定义一个继承 UserDefinedAggregateFunction 的聚合函数并重写其各个方法,每个方法的含义见代码中的注释。
package cn.UDAF
import java.lang
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession}
/**
 * Demo entry point: computes the geometric mean of the integers 1 to 10
 * by running the `GeoMean` UDAF through Spark SQL.
 *
 * @Author xiaohuli
 * @CreateDate 2019/2/6
 */
object UdafTest {
  /**
   * Builds a local Spark session, registers `GeoMean` as the SQL function
   * `gm`, applies it to the `id` column of a temp view and prints the result.
   *
   * @param args command-line arguments; not used
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UdafTest")
      .master("local[4]")
      .getOrCreate()

    // try/finally so the session is stopped even when the query fails.
    try {
      // The upper bound of spark.range is exclusive, so this holds the ids 1..10.
      val range: Dataset[lang.Long] = spark.range(1, 11)

      // Expose the UDAF to SQL under the name "gm".
      spark.udf.register("gm", GeoMean)

      // createOrReplaceTempView instead of createTempView: getOrCreate() may
      // return an already-running session in which "v_range" is registered,
      // and createTempView would then fail with an "already exists" error.
      range.createOrReplaceTempView("v_range")

      // Run the aggregation and print the single-row result.
      val result = spark.sql("SELECT gm(id) result FROM v_range")
      result.show()
    } finally {
      // Release the session's resources.
      spark.stop()
    }
  }
}
/**
 * User-defined aggregate function (UDAF) that computes the geometric mean
 * of a numeric column: the n-th root of the product of the n non-null
 * input values.
 *
 * The aggregation buffer carries the running product and the number of
 * values folded into it.
 *
 * NOTE(review): the running product is a plain Double, so it can overflow
 * to Infinity (or underflow to 0.0) for long or extreme inputs. A
 * sum-of-logarithms formulation would avoid that but would change results
 * for negative inputs, so the original formulation is kept.
 */
object GeoMean extends UserDefinedAggregateFunction {

  /** Input schema: a single column named "value" of type DoubleType. */
  override def inputSchema: StructType = StructType(List(
    StructField("value", DoubleType)
  ))

  /** Schema of the intermediate aggregation buffer. */
  override def bufferSchema: StructType = StructType(List(
    // Running product of the values seen so far.
    StructField("product", DoubleType),
    // Number of values that were multiplied into the product.
    StructField("counts", LongType)
  ))

  /** Type of the final result. */
  override def dataType: DataType = DoubleType

  /** The same input always yields the same output. */
  override def deterministic: Boolean = true

  /**
   * Resets the buffer to the neutral state: product 1.0 (the multiplicative
   * identity) and count 0.
   *
   * @param buffer aggregation buffer to initialize
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 1.0
    buffer(1) = 0L
  }

  /**
   * Partial (per-partition) aggregation: folds one input row into the buffer.
   *
   * NULL inputs are skipped. Reading a NULL with `getDouble` would throw and
   * fail the whole query, and SQL aggregates conventionally ignore NULLs.
   *
   * @param buffer aggregation buffer to update in place
   * @param input  input row; column 0 holds the value
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      // Multiply the value into the running product.
      buffer(0) = buffer.getDouble(0) * input.getDouble(0)
      // One more value took part in the product.
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  /**
   * Global aggregation: merges the partial buffer `buffer2` into `buffer1`.
   *
   * @param buffer1 target buffer, updated in place
   * @param buffer2 partial buffer to merge in
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Partial products multiply together.
    buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
    // Partial counts add up to the total count.
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /**
   * Computes the final result: product raised to the power 1/count.
   *
   * When no value was aggregated (count 0) the exponent is Infinity and
   * `math.pow(1.0, Infinity)` is NaN, so the result is NaN in that case.
   *
   * @param buffer final aggregation buffer
   * @return the geometric mean of the aggregated values
   */
  override def evaluate(buffer: Row): Double = {
    math.pow(buffer.getDouble(0), 1.0 / buffer.getLong(1))
  }
}