spark dataFrame自定义函数 聚合array

1 篇文章 0 订阅

spark dataFrame自定义聚合函数

spark2.4

scala2.11

准备环境

    val spark = SparkSession.builder()
      .master("local[*]")
      .appName(this.getClass.getSimpleName)
      .getOrCreate()
    val sc = spark.sparkContext

准备数据

    //造数据
    val dataList = List(("A", List("v1", "v2")),
      ("A", List("v1", "v3")),
      ("B", List("v1", "v2")),
      ("B", List("v3", "v4")),
      ("B", List("v1", "v3")))
      
    import spark.implicits._
    val initDF = sc.parallelize(dataList).toDF("key","v_list")
+---+--------+
|key|v_list  |
+---+--------+
|A  |[v1, v2]|
|A  |[v1, v3]|
|B  |[v1, v2]|
|B  |[v3, v4]|
|B  |[v1, v3]|
+---+--------+

需求

 //需求结果
    // key    all_v_list
    // A      [v1,v2,v3]
    // B      [v1,v2,v3,v4]
    // all_v_list 不需要排序 去重即可

自定义函数

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.json4s.DefaultFormats
import scala.collection.mutable
object ListAggNoType extends UserDefinedAggregateFunction {
  // 聚合函数输入数据结构
  override def inputSchema: StructType = StructType(StructField("element", ArrayType(StringType, containsNull = true)) :: Nil)
  // 缓存区数据结构,用于计算
  override def bufferSchema: StructType = StructType(StructField("buffer", ArrayType(StringType, containsNull = true)) :: Nil)
  // 聚合函数输出值数据结构
  override def dataType: DataType = ArrayType(StringType, containsNull = true)
  // 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
  override def deterministic: Boolean = true
  // 初始化缓冲区
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = List.empty[String]
  }
  // 给聚合函数传入一条新数据进行处理
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    implicit val defaultFormats: DefaultFormats.type = org.json4s.DefaultFormats
    var list: List[String] = buffer.get(0).asInstanceOf[mutable.WrappedArray[String]].toList
    val inputList: List[String] = input.get(0).asInstanceOf[mutable.WrappedArray[String]].toList
    if(inputList.isEmpty){
      list = list.:::(Nil).distinct
    }else{
      list = list.:::(inputList).distinct
    }
    buffer(0) = list
  }
  // 合并聚合函数缓冲区(分布式)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val listResult = buffer1.get(0).asInstanceOf[mutable.WrappedArray[String]].toList
    val listTemp = buffer2.get(0).asInstanceOf[mutable.WrappedArray[String]].toList
    buffer1(0) = (listResult ++ listTemp).distinct
  }
  // 计算最终返回结果
  override def evaluate(buffer: Row): Any = {
    val list = buffer.get(0).asInstanceOf[mutable.WrappedArray[String]].reverse.toList
    list
  }
}

注册函数

//注册函数
    spark.udf.register("ListAggNoType",ListAggNoType)

测试

    initDF.createOrReplaceTempView("table_A")
    val resDF = spark.sql(
      """
        |select
        |key,
        |ListAggNoType(v_list) as all_v_list
        |from
        |table_A
        |group by key
        |
        |""".stripMargin)
    resDF.show(false)

结果展示

+---+----------------+
|key|all_v_list      |
+---+----------------+
|B  |[v4, v3, v2, v1]|
|A  |[v3, v2, v1]    |
+---+----------------+

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值