Spark SQL实现自定义聚合函数

概述

spark自定义聚合函数需要继承一个抽象类UserDefinedAggregateFunction,并需要重写属性和方法:

  • inputSchema:函数的参数列表,不过需要写成StructType的格式
  • bufferSchema:中间结果的类型,比如求和时,a、b、c相加,需要先计算a+b并保存结果ab,然后计算ab+c,这个ab就是中间结果。
  • dataType:返回值结果类型,显示是DataType,也就是org.apache.spark.sql.types包下的那些类
  • deterministic: 结果是否是确定性的,即相同的输入,是否一定会有相同的输出
  • initialize:初始化中间结果,例如求和函数,开始计算前需要先将中间结果赋值为0
  • update(buffer: MutableAggregationBuffer, input: Row):更新中间结果,input是dataframe的一行,buffer是整个分片遍历过来的中间结果。
  • merge(buffer1:MutableAggregationBuffer,buffer2:Row):分片的合并,buffer2一个分片的中间结果,buffer1是整个合并过程的中间结果
  • evaluate(buffer:Row):返回函数结果,buffer是7的合并过程的中间结果buffer1遍历所有分片结束后的结果。

实现自定义聚合函数代码

package cn.demo.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

/**
  * 自定义聚合函数
  */
class GroupConcatDistinctUDAF extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(StructField("cityInfo", StringType) :: Nil)

  override def bufferSchema: StructType = StructType(StructField("bufferCityInfo", StringType) :: Nil)

// dataType,指的是,函数返回值的类型
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true
  
  // 为每个分组的数据执行初始化操作
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)= ""
  }

  /**
    * 更新
    * 可以认为是,一个一个地将组内的字段值传递进来
    * 实现拼接的逻辑
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // 缓冲中的已经拼接过的城市信息串
    var bufferCityInfo = buffer.getString(0)
    // 刚刚传递进来的某个城市信息
    val cityInfo = input.getString(0)

    // 在这里要实现去重的逻辑
    // 判断:之前没有拼接过某个城市信息,那么这里才可以接下去拼接新的城市信息
    if(!bufferCityInfo.contains(cityInfo)) {
      if("".equals(bufferCityInfo))
        bufferCityInfo += cityInfo
      else {
        // 比如1:北京
        // 1:北京,2:上海
        bufferCityInfo += "," + cityInfo
      }

      buffer.update(0, bufferCityInfo)
    }
  }
   
  // 由于Spark是分布式的,所以一个分组的数据,可能会在不同的节点上进行局部聚合,就是update
  // 但是,最后一个分组,在各个节点上的聚合值,要进行merge,也就是合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var bufferCityInfo1 = buffer1.getString(0);
    val bufferCityInfo2 = buffer2.getString(0);

    for(cityInfo <- bufferCityInfo2.split(",")) {
      if(!bufferCityInfo1.contains(cityInfo)) {
        if("".equals(bufferCityInfo1)) {
          bufferCityInfo1 += cityInfo;
        } else {
          bufferCityInfo1 += "," + cityInfo;
        }
      }
    }

    buffer1.update(0, bufferCityInfo1);
  }
  
   // 最后,指的是,一个分组的聚合值,如何通过中间的缓存聚合值,最后返回一个最终的聚合值
  override def evaluate(buffer: Row): Any = {
    buffer.getString(0)
  }

}

实现注册自定义聚合函数

package cn.demo.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object GroupConcatMain {
  def main(args: Array[String]): Unit = {
    // 构建Spark上下文
    val sparkConf = new SparkConf().setAppName("SessionAnalyzer").setMaster("local[*]")

    // 创建Spark客户端
    val spark = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
    val sc = spark.sparkContext

    // 注册自定义函数
    spark.udf.register("concat_long_string", new GroupConcatDistinctUDAF)
  }
}

这样就可以在spark中使用自定义聚合函数了

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值