spark dataFrame自定义聚合函数
环境版本:Spark 2.4、Scala 2.11
准备环境
// Build (or reuse) a local SparkSession; app name is taken from the enclosing class.
val spark = SparkSession.builder()
  .appName(this.getClass.getSimpleName)
  .master("local[*]")
  .getOrCreate()
// Handle on the underlying SparkContext for RDD-based APIs.
val sc = spark.sparkContext
准备数据
// Sample input: (key, list-of-values) pairs to be aggregated per key.
val sampleRows = Seq(
  ("A", List("v1", "v2")),
  ("A", List("v1", "v3")),
  ("B", List("v1", "v2")),
  ("B", List("v3", "v4")),
  ("B", List("v1", "v3")))
import spark.implicits._
val initDF = sc.parallelize(sampleRows).toDF("key", "v_list")
+---+--------+
|key|v_list |
+---+--------+
|A |[v1, v2]|
|A |[v1, v3]|
|B |[v1, v2]|
|B |[v3, v4]|
|B |[v1, v3]|
+---+--------+
需求
//需求结果
// key all_v_list
// A [v1,v2,v3]
// B [v1,v2,v3,v4]
// all_v_list 不需要排序 去重即可
自定义函数
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.json4s.DefaultFormats
import scala.collection.mutable
/**
 * Untyped UDAF that collects the distinct string elements of an array column
 * into a single array per group. Element order is NOT guaranteed — the
 * requirement only asks for de-duplication, and merge order depends on
 * partitioning.
 */
object ListAggNoType extends UserDefinedAggregateFunction {
  // Input schema: one column holding an array of strings (array itself may contain nulls).
  override def inputSchema: StructType =
    StructType(StructField("element", ArrayType(StringType, containsNull = true)) :: Nil)

  // Buffer schema: the running, de-duplicated sequence of strings.
  override def bufferSchema: StructType =
    StructType(StructField("buffer", ArrayType(StringType, containsNull = true)) :: Nil)

  // Output type: array of strings.
  override def dataType: DataType = ArrayType(StringType, containsNull = true)

  // Deterministic: the same input rows always produce the same element set.
  // (Note: this flag means "deterministic", not "idempotent".)
  override def deterministic: Boolean = true

  // Start each group with an empty sequence.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[String]
  }

  // Fold one input row into the buffer, keeping only distinct elements.
  // The schema allows a null array value, so guard before reading it;
  // the original unguarded WrappedArray cast would NPE on a null row.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val acc = buffer.getSeq[String](0)
      buffer(0) = (acc ++ input.getSeq[String](0)).distinct
    }
  }

  // Merge two partial aggregation buffers (runs across partitions).
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = (buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0)).distinct
  }

  // Final result: the accumulated distinct elements as-is (no ordering promised,
  // so the previous .reverse was unnecessary).
  override def evaluate(buffer: Row): Any = buffer.getSeq[String](0)
}
注册函数
// Register the UDAF under a name usable from Spark SQL.
spark.udf.register("ListAggNoType", ListAggNoType)
测试
// Expose the DataFrame as a temp view, then aggregate each key's lists via SQL.
initDF.createOrReplaceTempView("table_A")
val aggregated = spark.sql(
  """
|select
|key,
|ListAggNoType(v_list) as all_v_list
|from
|table_A
|group by key
|
|""".stripMargin)
// Print the full (untruncated) result.
aggregated.show(false)
结果展示(注:all_v_list 内元素的顺序取决于分区与合并顺序,并不固定;需求只要求去重,不要求排序)
+---+----------------+
|key|all_v_list |
+---+----------------+
|B |[v4, v3, v2, v1]|
|A |[v3, v2, v1] |
+---+----------------+