Spark 之 UDAF

一、UDAF简介
先解释一下什么是UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是什么呢,普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。

关于UDAF的一个误区
我们可能下意识的认为UDAF是需要和group by一起使用的,实际上UDAF可以跟group by一起使用,也可以不跟group by一起使用,这个其实比较好理解,联想到mysql中的max、min等函数,可以:
select max(foo) from foobar group by bar;
表示根据bar字段分组,然后求每个分组的最大值,这时候的分组有很多个,使用这个函数对每个分组进行处理,也可以:

select max(foo) from foobar;
这种情况可以将整张表看做是一个分组,然后在这个分组(实际上就是一整张表)中求最大值。所以聚合函数实际上是对分组做处理,而不关心分组中记录的具体数量。

二、UDAF使用
继承UserDefinedAggregateFunction
使用UserDefinedAggregateFunction的套路:
1、自定义类继承UserDefinedAggregateFunction,对每个阶段方法做实现
2、在spark中注册UDAF,为其绑定一个名字
3、然后就可以在sql语句中使用上面绑定的名字调用

实例,求平均年龄
user.json内容如下

{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}

实现代码

package nj.zb.kb09.sql

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object SparkUDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]")
      .appName("SparkStudy").getOrCreate()

    val df = spark.read.json("in/user.json")
//    df.show()
//    df.printSchema()

    //创建并注册自定义udaf函数
    val myUDAF=new AvgAgeUADFFunction
    spark.udf.register("AvgAge",myUDAF)

    df.createTempView("userInfo")
    val resDF = spark.sql("select sex,round(AvgAge(age),3) from userInfo group by sex")
    resDF.show()
    resDF.printSchema()
  }
}

class AvgAgeUADFFunction extends UserDefinedAggregateFunction {

  // 聚合函数的输入数据结构
  override def inputSchema: StructType = StructType(
//    new StructField("age",LongType)
    StructField("input",LongType)::Nil
  )

  // 缓存区数据结构
  override def bufferSchema: StructType = StructType(
    StructField("sum",LongType)::StructField("count",LongType)::Nil
  )

  // 聚合函数返回值数据结构
  override def dataType: DataType = DoubleType

  // 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
  override def deterministic: Boolean = true

  // 初始化缓冲区
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0L
    buffer(1)=0L
  }

  // 给聚合函数传入一条新数据进行处理
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.isNullAt(0)) return
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // 合并聚合函数缓冲区
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // 计算最终结果
  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)
}

输出
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值