Spark SQL自定义聚合函数(弱类型)

UDAF的使用(弱类型 基于DataFrame)

用户自定义UDAF聚合函数需要实现以下两个步骤:
1、弱类型聚合函数
继承UserDefinedAggregateFunction
2、注册为函数:ss.udf.register(“avgCus”, new CusAvgFun)

package SparkSQL

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/**
  * 自定义UDAF函数
  * 聚合函数 对多行数据生效  进来多行  输出一行
  */
object TestUDAFFunction {
  def main(args: Array[String]): Unit = {
    val ss = SparkSession.builder().master("local").appName("UDAF Function").getOrCreate()
    val sc = ss.sparkContext
    import ss.implicits._
    var data: DataFrame = sc.parallelize(Array(("zs", 15), ("ls", 20), ("ww", 18), ("ml", 25), ("zq", 30))).toDF("name", "ageAndheigth")
    //注册聚合函数
    ss.udf.register("avgCus", new CusAvgFun)
    data.createGlobalTempView("student")
    //聚合函数的使用
    ss.sql("select avgCus(ageAndheigth) as valuecu from global_temp.student").show()
  }
}

/**
  * 用户自定义聚合函数
  */
class CusAvgFun extends UserDefinedAggregateFunction {
  //输入数据的类型  类似于创建dataframe时候指定列的数据类型  在整个底层计算中应该是以row传递数据
  override def inputSchema: StructType = StructType(StructField("data", LongType) :: Nil)

  //缓冲区中值的数据类型  这里就是你在计算过程中所需要的数据的类型   如果求平均数在这里就是两个中间值 封装成Row传值
  override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  //输出数据的类型
  override def dataType: DataType = DoubleType

  //如果此函数是确定性的,即给定相同的输入,返回true,始终返回相同的输出。
  override def deterministic: Boolean = true

  //初始化给定的聚合缓冲区,即聚合缓冲区的零值。约定应该是在两个初始缓冲区上应用合并函数只应返回初始缓冲区本身。
  //这里说一下 这里是中间值的数据  因为刚刚这中间值一共设置了两个   所以这里要按照顺序更新两个值
  //分别是sum值和buffer值
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0l)
    buffer.update(1, 0l)
  }

  //更新值 buffer代表的是缓冲区的值  input代表的是新输入的数据的值
  //其中input代表的是新输入的数据 封装成Row,其中Row中有一个值 这个值就是刚定义的输入数据的类型 long类型的data
  //buffer中有两个值  分别是定义的缓冲区的值 一个sum 一个count
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getLong(0) + input.getLong(0))
    buffer.update(1, buffer.getLong(1) + 1l)
  }

  //合并两个聚合缓冲区并将更新的缓冲区值存储回“buffer1”。当我们将两个部分聚合的数据合并在一起时,会调用此方法。
  //会将一个缓冲区的数据拉取到另一个缓冲区完成合并。所以在第二个参数buffer2就是另外一个缓冲区的数据只不过封装成了Row
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  //根据给定的聚合缓冲区计算此[[UserDefinedAggregateFunction]]的最终结果。
  //这里的是指最终缓冲区的数据  有两个值  分区是刚定义的缓冲区的数据类型 sum和count 按顺序的
  override def evaluate(buffer: Row): Any = {
    (buffer.getLong(0) / buffer.getLong(1)).toDouble
  }
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值