package com.erongda.bigdata.scala.spark.sql
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import scala.collection.Seq._
import scala.Predef.Set
/**
* 自定义UDAF函数 实现去重效果
*/
object DistinctUDAF extends UserDefinedAggregateFunction {
import scala.collection.immutable.Set
var set: Set[Int] = Set()
/**
* 输入的数据类型
* @return
*/
override def inputSchema: StructType = StructType(
Array(StructField("sal",IntegerType,nullable = true))
)
/**
* 表示的是聚合过程中 缓冲临时变量的数据类型,也是封装在StructType里面
* @return
*/
override def bufferSchema: StructType = StructType(
Array(StructField("sal_dis",IntegerType, nullable = true))
)
/**
* 输出的数据类型
* @return
*/
override def dataType: DataType = IntegerType
/**
*
* @return
*/
override def deterministic: Boolean = true
/**
* 初始化 对聚合中的 临时缓冲的值
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0)=0
}
/**
* TODO: 针对每个分区数据进行操作的
* 表示的是针对每条数据进行聚合函数以后,对聚合缓冲临时变量的更新
* @param buffer
* 聚合缓冲临时变量
* @param input
* 输入变量
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//获得缓冲临时的变量值
// val saldis = buffer.getSeq(0).toSet
//获取输入传递进来的值
// val inputSal = input.getDouble(0)
// val value_set = (buffer.getSeq(0).toSet + input.getDouble(0)).toSeq
//更新缓冲数据的值
// buffer.update(0, value_set)
buffer(0)=buffer.getInt(0)+input.getInt(0)
set += (input.getInt(0))
}
/**
* TODO: 针对合并所有分区的聚合结果
* 从字面意思看,就是合并的意思
* This is called when we merge two partially aggregated data together.
*
* 表示的是:
* 正对不同分区 聚合的结果 进行合并操作
*
* Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
* 表示将两个聚合缓冲的数据,合并以后并存储到 buffer1中
*
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//分别获得缓冲临时变量的值
// val saldis = buffer1.getSeq(0).toSet
//
// val saldis2 = buffer2.getSeq(0).toSet
//
// val total = (saldis ++ saldis2).toSeq
// //合并更新
// buffer1.update(0, total)
buffer1(0)=buffer1.getInt(0)+buffer2.getInt(0)
set ++= set
}
/**
* 表示的是一句 聚合缓冲的临时变量 计算 聚合最终结果
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = {
//a. 获取聚合缓冲数据
// val saldis = buffer.getSeq(0).length
//b.计算并返回
set.size
}
}
自定义UDAF函数实现去重统计效果
最新推荐文章于 2022-11-18 09:05:23 发布