/**
 * UDAF: computes the average score per questionnaire for each evaluation record.
 *
 * Semantics preserved from the original: each distinct `index` contributes its
 * score at most once, and zero scores are ignored entirely.
 *
 * Buffer redesign (bug fix): the original buffer kept (sum, count, concatenated
 * index string). That design was broken in two ways:
 *   1. `merge` added sum/count from two partitions without deduplicating, so an
 *      index seen on both partitions was double-counted — and could not be
 *      corrected, because per-index scores were not retained.
 *   2. Dedup used substring `contains`, so index "1" falsely matched "12", and
 *      `merge` joined the strings without the "_" separator used in `update`.
 * Storing a Map[index -> score] fixes all of this while leaving the external
 * contract (inputSchema, dataType) unchanged.
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0 in
 * favor of org.apache.spark.sql.expressions.Aggregator — consider migrating.
 * Requires org.apache.spark.sql.types.MapType in scope (covered by a wildcard
 * import of org.apache.spark.sql.types._).
 */
class CalScoreAvl extends UserDefinedAggregateFunction {

  // Input columns: (index: String, score: Double).
  // The names are placeholders; they do not need to match the actual columns.
  override def inputSchema: StructType =
    StructType(Array(StructField("index", StringType), StructField("score", DoubleType)))

  // Intermediate state: index -> score. Keys give O(1) exact-match dedup that
  // is correct both within a partition (update) and across partitions (merge).
  override def bufferSchema: StructType =
    StructType(Array(StructField("scores", MapType(StringType, DoubleType))))

  // Final result of the UDAF: the average score as a Double.
  override def dataType: DataType = DoubleType

  // Same input always yields the same output.
  override def deterministic: Boolean = true

  // Start each aggregation with an empty index -> score map.
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, Map.empty[String, Double])

  // Accumulate one input row: record the score only if the index has not been
  // seen yet AND the score is non-zero (both rules carried over from the
  // original implementation). Null index/score rows are skipped defensively.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0) && !input.isNullAt(1)) {
      val index = input.getString(0)
      val score = input.getDouble(1)
      val seen = buffer.getMap[String, Double](0)
      if (score != 0 && !seen.contains(index)) {
        buffer.update(0, seen + (index -> score))
      }
    }
  }

  // Combine partial aggregates from two partitions. Union of the maps keeps
  // each index exactly once; on a conflict the left (buffer1) score wins,
  // matching update's first-seen-wins behavior.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val left = buffer1.getMap[String, Double](0)
    val right = buffer2.getMap[String, Double](0)
    buffer1.update(0, right.toMap ++ left.toMap)
  }

  // Final result: mean of the deduplicated scores. Returns null (SQL NULL)
  // instead of the original's 0.0/0 = NaN when no qualifying rows were seen.
  override def evaluate(buffer: Row): Any = {
    val scores = buffer.getMap[String, Double](0)
    if (scores.isEmpty) null
    else scores.values.sum / scores.size
  }
}
// Spark UDAF usage notes (原文: spark udaf 操作说明)
// Source blog footer — latest recommended article published 2023-03-16 14:05:25