Based on Spark 2.2.0

In Spark SQL we routinely use the built-in aggregate functions such as Max and Min. This post shows how to define and use a custom aggregate function (UDAF) in Spark SQL.

Basic definition

To implement a custom UDAF, extend UserDefinedAggregateFunction.
The skeleton below lists the methods that must be implemented and what their parameters mean:
package com.spark.test.offline.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StructType}

/**
 * Created by szh on 2020/5/31.
 *
 * @author szh
 */
class TestSum extends UserDefinedAggregateFunction {

  // Data type of the aggregate function's input parameters
  override def inputSchema: StructType = ???

  // Data type of the intermediate aggregation buffer
  override def bufferSchema: StructType = ???

  // Data type of the final result
  override def dataType: DataType = ???

  // Whether the function always returns the same result for the same input
  override def deterministic: Boolean = ???

  // Initial buffer value; if the DataSet has no data, this is what gets returned
  override def initialize(buffer: MutableAggregationBuffer): Unit = ???

  /**
   * Called once per input row within a partition; the running result is kept in the buffer.
   *
   * @param buffer aggregation buffer for the current partition
   * @param input  the current input row
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???

  /**
   * Combines the buffers of different partitions.
   *
   * @param buffer1 buffer of one partition (also receives the merged result)
   * @param buffer2 buffer of another partition
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???

  // Computes the final result from the buffer
  override def evaluate(buffer: Row): Any = ???
}
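To make the call order concrete: Spark initializes a buffer per partition, calls update once for every input row, combines the partition buffers with merge, and finally calls evaluate on the merged buffer. The plain-Scala sketch below (not Spark code; the numbers are just sample values) mimics that flow for a simple sum:

// A minimal local sketch of the UDAF lifecycle, in plain Scala (no Spark involved).
object UdafLifecycleSketch {
  def main(args: Array[String]): Unit = {
    // Two hypothetical partitions of input values
    val partitions = Seq(Seq(86L, 88L), Seq(90L, 60L))

    // Per partition: initialize the buffer, then update it once per row
    val partitionBuffers = partitions.map { rows =>
      var buffer = 0L                 // initialize
      rows.foreach(v => buffer += v)  // update
      buffer
    }

    // Across partitions: merge the buffers, then evaluate the final result
    val merged = partitionBuffers.reduce(_ + _) // merge
    println(merged)                             // evaluate -> 324
  }
}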
Three examples

We implement three UDAFs:

customer_sum
customer_avg
customer_max

customer_sum
package com.spark.test.offline.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

/**
 * Created by szh on 2020/5/31.
 */
object CustomerSum extends UserDefinedAggregateFunction {

  // Data type of the aggregate function's input parameters
  def inputSchema: StructType = {
    StructType(StructField("inputColumn", LongType) :: Nil)
  }

  // Data type of the intermediate aggregation buffer; a single running sum is enough here
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: Nil)
  }

  // Data type of the final result
  def dataType: DataType = LongType

  def deterministic: Boolean = true

  // Initial buffer value; if the DataSet has no data, this is what gets returned
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
  }

  /**
   * Called once per input row within a partition; the running sum is kept in the buffer.
   *
   * @param buffer aggregation buffer for the current partition
   * @param input  the current input row
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
    }
  }

  /**
   * Combines the buffers of different partitions.
   *
   * @param buffer1 buffer of one partition (also receives the merged result)
   * @param buffer2 buffer of another partition
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  }

  // Computes the final result from the buffer
  def evaluate(buffer: Row): Long = buffer.getLong(0)
}
customer_avg
package com.spark.test.offline.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Created by szh on 2020/5/31.
 */
object CustomerAvg extends UserDefinedAggregateFunction {

  // Data type of the aggregate function's input parameters
  def inputSchema: StructType = {
    StructType(StructField("inputColumn", LongType) :: Nil)
  }

  // Data type of the intermediate aggregation buffer: a running sum and a row count
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }

  // Data type of the final result
  def dataType: DataType = DoubleType

  def deterministic: Boolean = true

  // Initial buffer value; if the DataSet has no data, this is what gets returned
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  /**
   * Called once per input row within a partition; the running sum and count are kept in the buffer.
   *
   * @param buffer aggregation buffer for the current partition
   * @param input  the current input row
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0) // sum
      buffer(1) = buffer.getLong(1) + 1                // count
    }
  }

  /**
   * Combines the buffers of different partitions.
   *
   * @param buffer1 buffer of one partition (also receives the merged result)
   * @param buffer2 buffer of another partition
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) // sum
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) // count
  }

  // Computes the final result from the buffer
  def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}
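For comparison, Spark 2.x also ships a type-safe way to write the same aggregation: org.apache.spark.sql.expressions.Aggregator, used with typed Datasets. The sketch below is only a minimal illustration of that alternative; the names AvgBuffer and TypedAvg are mine, not part of the original code.

package com.spark.test.offline.udf

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Buffer type for the typed average (illustrative name)
case class AvgBuffer(sum: Long, count: Long)

object TypedAvg extends Aggregator[Long, AvgBuffer, Double] {
  // Empty buffer
  def zero: AvgBuffer = AvgBuffer(0L, 0L)
  // Fold one input value into the buffer
  def reduce(b: AvgBuffer, a: Long): AvgBuffer = AvgBuffer(b.sum + a, b.count + 1)
  // Combine two partition buffers
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)
  // Compute the final result
  def finish(r: AvgBuffer): Double = r.sum.toDouble / r.count
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

It would be applied to a typed Dataset, for example scores.select(TypedAvg.toColumn.name("avg_score")) for a hypothetical Dataset[Long] named scores.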
customer_max
package com.spark.test.offline.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Created by szh on 2020/5/31.
 */
object CustomerMax extends UserDefinedAggregateFunction {

  // Data type of the aggregate function's input parameters
  def inputSchema: StructType = {
    StructType(StructField("inputColumn", LongType) :: Nil)
  }

  // Data type of the intermediate aggregation buffer: the current maximum
  def bufferSchema: StructType = {
    StructType(StructField("max", LongType) :: Nil)
  }

  // Data type of the final result
  def dataType: DataType = LongType

  def deterministic: Boolean = true

  // Initial buffer value; if the DataSet has no data, this is what gets returned
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
  }

  /**
   * Called once per input row within a partition; the current maximum is kept in the buffer.
   *
   * @param buffer aggregation buffer for the current partition
   * @param input  the current input row
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      if (input.getLong(0) > buffer.getLong(0)) {
        buffer(0) = input.getLong(0)
      }
    }
  }

  /**
   * Combines the buffers of different partitions.
   *
   * @param buffer1 buffer of one partition (also receives the merged result)
   * @param buffer2 buffer of another partition
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer2.getLong(0) > buffer1.getLong(0)) buffer1(0) = buffer2.getLong(0)
  }

  // Computes the final result from the buffer
  def evaluate(buffer: Row): Long = buffer.getLong(0)
}
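Note that because initialize starts the buffer at 0L, CustomerMax only returns a correct maximum when all input values are non-negative, which happens to hold for the score data below. To handle arbitrary inputs, the buffer would instead be initialized with Long.MinValue (accepting that an empty input then yields Long.MinValue), or a second buffer field would track whether any value has been seen.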
Sample driver code
package com.spark.test.offline.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/**
 * Created by szh on 2020/5/31.
 */
object SparkSQLUdaf {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf
    conf
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      // .set("spark.kryo.registrationRequired", "true")
      // Register classes with Kryo explicitly
      .registerKryoClasses(
        Array(
          classOf[scala.collection.mutable.WrappedArray.ofRef[_]]
        ))

    val spark = SparkSession
      .builder()
      .appName("sparkSql")
      .master("local[1]")
      .config(conf)
      .getOrCreate()
    val sc = spark.sparkContext
    sc.setLogLevel("ERROR")

    // Columns: user_id, project_id, score
    val rddA = sc.parallelize(Seq(
      (1, 1, 86.3)
      , (1, 2, 88.3)
      , (1, 3, 90.3)
      , (2, 1, 60.0)
    ))

    spark
      .createDataFrame(rddA)
      .toDF("user_id", "project_id", "score")
      .createTempView("user_info")

    // Register the three UDAFs and call them through SQL
    spark.udf.register("com_sum", CustomerSum)
    val df = spark.sql("SELECT user_id, com_sum(score) FROM user_info GROUP BY user_id")
    df.show()

    spark.udf.register("com_avg", CustomerAvg)
    val df2 = spark.sql("SELECT user_id, com_avg(score) FROM user_info GROUP BY user_id")
    df2.show()

    spark.udf.register("com_max", CustomerMax)
    val df3 = spark.sql("SELECT com_max(score) FROM user_info")
    df3.show()

    sc.stop()
    spark.stop()
  }
}
Output
+-------+-----------------------------------+
|user_id|customersum$(CAST(score AS BIGINT))|
+-------+-----------------------------------+
|      1|                                264|
|      2|                                 60|
+-------+-----------------------------------+

+-------+-----------------------------------+
|user_id|customeravg$(CAST(score AS BIGINT))|
+-------+-----------------------------------+
|      1|                               88.0|
|      2|                               60.0|
+-------+-----------------------------------+
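The same UDAF objects can also be invoked through the DataFrame API without registering them for SQL, since UserDefinedAggregateFunction exposes an apply method that returns a Column. A minimal sketch, assuming the user_info temp view created in the driver above (the name userInfoDF is illustrative):

import org.apache.spark.sql.functions.col

// Hedged sketch: calling the UDAFs directly on a DataFrame instead of via SQL.
// userInfoDF stands for the DataFrame with user_id / project_id / score created above.
val userInfoDF = spark.table("user_info")

userInfoDF
  .groupBy("user_id")
  .agg(
    CustomerSum(col("score").cast("long")).as("sum_score"),  // explicit cast to match inputSchema
    CustomerAvg(col("score").cast("long")).as("avg_score")
  )
  .show()

userInfoDF
  .agg(CustomerMax(col("score").cast("long")).as("max_score"))
  .show()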