UDAFs are rarely needed in practice: most of them can be replaced by a more concise UDF or a combination of built-in functions (see the sketches after the full listing below).
import org.apache.log4j.Logger
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataTypes, StringType, StructField, StructType}
object UDAFTest {

  private val logger = Logger.getLogger(this.getClass)

  val spark: SparkSession = SparkSession
    .builder()
    .appName("local-test")
    .master("local[4]")
    //.enableHiveSupport()
    //.config("spark.shuffle.service.enabled", true)
    //.config("spark.driver.maxResultSize", "4G")
    //.config("spark.sql.parquet.writeLegacyFormat", true)
    .getOrCreate()

  import spark.implicits._

  spark.sparkContext.setLogLevel("warn")
  def main(args: Array[String]): Unit = {
    val t1 = System.currentTimeMillis()

    // sample data: (use_id, occur_time, value)
    val source = Array(
      ("11111111", "2021-03-29 00:00:00", "3"),
      ("11111111", "2021-03-29 00:01:00", "5"),
      ("11111111", "2021-03-29 00:02:00", "7"),
      ("11111111", "2021-03-29 00:03:00", "9"),
      ("11111111", "2021-03-29 00:04:00", "11"),
      ("22222222", "2021-03-29 00:00:00", "3"),
      ("22222222", "2021-03-29 00:01:00", "5"),
      ("22222222", "", "7"),
      ("22222222", "2021-03-29 00:03:00", "9"),
      ("22222222", "2021-03-29 00:04:00", "11")
    )

    val rawRDD = spark.sparkContext.parallelize(source)
    val rawDF = rawRDD.toDF("use_id", "occur_time", "value")
      .selectExpr(
        "use_id",
        "occur_time",
        "cast(value as double) as value"
      )
    rawDF.printSchema()
    rawDF.show(false)

    testMaxWhen(rawDF)

    val t2 = System.currentTimeMillis()
    logger.warn("======== run time is: " + ((t2 - t1) / 1000) + "s")
    spark.stop()
  }
  private def testMaxWhen(df: DataFrame) = {
    spark.udf.register("max_when", new MaxWhenUdaf)
    val tmpDF = df.groupBy("use_id")
      .agg(
        callUDF("max_when", $"value", lit(11)).as("max_when_test")
      )
    tmpDF.printSchema()
    tmpDF.show(false)
  }
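  // For the sample data above, both groups should aggregate to 9.0:
  // the max of `value` over rows where value != 11, returned as a string.
  // +--------+-------------+
  // |use_id  |max_when_test|
  // +--------+-------------+
  // |11111111|9.0          |
  // |22222222|9.0          |
  // +--------+-------------+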
  class MaxWhenUdaf extends UserDefinedAggregateFunction {

    // Data types of the input arguments
    override def inputSchema = {
      // alternative, equivalent ways to build a StructType:
      //DataTypes.createStructField("value_v", DataTypes.StringType, true)
      //StructType(Seq(StructField("value_v", StringType)))
      //StructType(StructField("value_v", StringType) :: Nil)
      new StructType()
        .add("value_v", "double")
        .add("value_tag", "double")
    }

    // Data types held in the aggregation buffer
    override def bufferSchema = {
      new StructType()
        .add("buffer_v", "double")
    }

    // Data type of the return value
    override def dataType = {
      DataTypes.StringType
    }

    // Marks whether the UDAF always produces the same result for a given
    // set of inputs. Usually true, so the optimizer can safely reuse results.
    override def deterministic = {
      true
    }

    // Initialize the buffer; start below any real value so the first update wins
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = Double.MinValue
    }

    /**
     * Update: the values of a group are passed in one at a time and folded
     * into the buffer. buffer.getAs[Double](0) is the result of the previous
     * partial aggregation. This is analogous to a map-side combiner: each task
     * pre-aggregates its own rows, and the global aggregation happens later on
     * the reduce side. In other words: whenever a new value arrives for a
     * group, decide how it changes the partial aggregate.
     */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      // Use isNullAt for null checks: getAs[Double] returns a primitive,
      // so comparing it against null (or "") can never detect a null input.
      if (!input.isNullAt(0) && !input.isNullAt(1)) {
        val calValue = input.getDouble(0)
        val calFlag = input.getDouble(1)
        if (calValue != calFlag && calValue > buffer.getDouble(0)) {
          buffer.update(0, calValue)
        }
      }
    }

    /**
     * Merge partial results.
     * update may only see part of a group's data: the rows of one group can be
     * processed across several partitions on different executors or nodes.
     * merge combines the partial buffers produced on each of them into one
     * global result.
     */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, math.max(buffer1.getDouble(0), buffer2.getDouble(0)))
    }

    // Produce the final result from the buffer
    override def evaluate(buffer: Row) = {
      buffer.getDouble(0).toString
    }
  }
}
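As the note at the top says, this particular aggregation does not actually need a UDAF: the built-in max and when functions express "max of value where value != tag" directly. A minimal sketch, assuming the rawDF built in main:

rawDF.groupBy("use_id")
  .agg(max(when($"value" =!= lit(11), $"value")).as("max_when_test"))

Rows where the condition is false become null inside when, and max simply ignores nulls. Note this version returns a double (9.0) rather than the string produced by evaluate; add a cast if the string type matters.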
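Also worth knowing: UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of org.apache.spark.sql.expressions.Aggregator registered through functions.udaf. A rough sketch of the same logic under the newer API (my own translation, assuming Spark 3.0+ and non-null inputs):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.{Encoder, Encoders}

object MaxWhenAgg extends Aggregator[(Double, Double), Double, String] {
  def zero: Double = Double.MinValue                           // like initialize
  def reduce(b: Double, in: (Double, Double)): Double =        // like update
    if (in._1 != in._2 && in._1 > b) in._1 else b
  def merge(b1: Double, b2: Double): Double = math.max(b1, b2) // like merge
  def finish(b: Double): String = b.toString                   // like evaluate
  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  def outputEncoder: Encoder[String] = Encoders.STRING
}

// registered and called the same way as before:
// spark.udf.register("max_when", udaf(MaxWhenAgg))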
A few other posts on the topic are also worth reading:
https://zhuanlan.zhihu.com/p/25587189
https://blog.csdn.net/kwu_ganymede/article/details/50462020
https://www.jianshu.com/p/ca3ce8baeffb
https://blog.csdn.net/fengfengchen95/article/details/88681780