spark3.0版本中sparkSQL自定义聚合函数(UDAF)

文章介绍了在Spark3.0版本中如何使用Aggregator类来创建自定义聚合函数,以计算年龄平均值为例,详细说明了继承Aggregator类并重写相关方法的步骤,包括定义输入、缓冲区和输出的泛型类型,以及zero、reduce、merge和finish等方法。
摘要由CSDN通过智能技术生成

spark3.0之前的版本中sparkSQL自定义聚合函数要继承UserDefinedAggregateFunction类,重写8个方法,具体使用方法可参考https://blog.csdn.net/weixin_43866709/article/details/88914871
但是该类是弱类型的,实现逻辑的时候容易出错。
spark3.0版本可以继承Aggregator

1.继承import org.apache.spark.sql.expressions.Aggregator,定义泛型
IN:输入的数据类型
BUF:缓冲区的数据类型
OUT:输出的数据类型
2.重写方法
3.注册自定义聚合函数
spark.udf.register(“函数名称”,functions.udaf(new MyAgeAvg()))

具体实现案例如下,实现一个简单个求平均值的自定义聚合函数:

package com.zsz.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, SparkSession, functions}

object Spark_SparkSQL_UDAF1 {
  def main(args: Array[String]): Unit = {

    val conf: SparkConf = new SparkConf().setMaster("local").setAppName("newUDAF")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    // 隐式转换
    import spark.implicits._

    val rdd: RDD[(String, Int)] = spark.sparkContext.makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangwu", 40)))

    val df: DataFrame = rdd.toDF("username", "age")

    df.createTempView("user")

    // 注册自定义聚合函数
    spark.udf.register("MyAgeAvg",functions.udaf(new MyAgeAvg()))

    spark.sql("select MyAgeAvg(age) from user").show()

    spark.close()
  }

  /**
   * 自定义聚合函数类:计算年龄的平均值
   * 1.继承import org.apache.spark.sql.expressions.Aggregator,定义泛型
   *  IN:输入的数据类型
   *  BUF:缓冲区的数据类型
   *  OUT:输出的数据类型
   * 2.重写方法
   */

  case class Buff( var total:Long, var count:Long )

  class MyAgeAvg extends Aggregator[Long,Buff,Long]{
    // 初始值
    override def zero: Buff = {
      Buff(0L,0L)
    }

    // 缓冲区数据计算
    override def reduce(buff: Buff, in: Long): Buff = {
      buff.total += in
      buff.count += 1
      buff
    }

    // 合并缓冲区
    override def merge(buff1: Buff, buff2: Buff): Buff = {
      buff1.total += buff2.total
      buff1.count += buff2.count
      buff1
    }

    // 输出值计算
    override def finish(buff: Buff): Long = {
      buff.total/buff.count
    }

    // 缓冲区编码设置
    override def bufferEncoder: Encoder[Buff] = {
      Encoders.product
    }

    // 输出编码
    override def outputEncoder: Encoder[Long] = {
      Encoders.scalaLong
    }
  }
}
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值