Spark UDAF的定义与使用

UDAF概述

UDAF定义

//创建class类继承UserDefinedAggregateFunction并重写其中的方法inputSchema、bufferSchema、dataType、deterministic、initialize、update、merge、evaluate
class 类名 extends UserDefinedAggregateFunction{}

UDAF的使用

	//注册自定义UDAF函数
    val 对象名= new 类名
    spark.udf.register("自定义UDAF名称",对象名)

    //在sapark.sql中操作指定列
    val df2: DataFrame = spark.sql("select 字段,UDAF名称(字段) from userinfo group by 字段")

UDAF示例


/*
user.json数据
{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
*/
object SparkUDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("UDAF").getOrCreate()
    import spark.implicits._
    val df: DataFrame = spark.read.json("in/user.json")
    
    //创建并注册自定义UDAF函数
    val function = new MyAgeAvgFunction
    spark.udf.register("myAvgAge",function)
	//创建视图
    df.createTempView("userinfo")
    //查询男女平均年龄
    val df2: DataFrame = spark.sql("select sex,myAvgAge(age) from userinfo group by sex")
    df2.show()
  }
}


//实现UDAF类
//实现的功能是对传入的数值进行累加,并且计数传入的个数,最后相除得到平均数
class MyAgeAvgFunction extends UserDefinedAggregateFunction{

  //聚合函数的输入数据结构
  override def inputSchema: StructType = {
	  new StructType().add(StructField("age",LongType))
  }
  
  //缓存区数据结构
  override def bufferSchema: StructType = {
 	 new StructType().add(StructField("sum",LongType)).add(StructField("count",LongType))
  }
  
  //聚合函数返回值数据结构
  override def dataType: DataType = DoubleType
  
  //聚合函数是否是幂等的,即相同输入是否能得到相同输出
  override def deterministic: Boolean = true

  //设定默认值
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //sum
    buffer(0)=0L
    //count
    buffer(1)=0L
  }
  
  //给聚合函数传入一条新数据时所需要进行的操作
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    //将传入的数据进行累加
    buffer(0)=buffer.getLong(0)+input.getLong(0)
    //每传入一次计数加一
    buffer(1)=buffer.getLong(1)+1
  }
  
  //合并聚合函数的缓冲区(不同分区)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //不同分区的数据进行累加
    buffer1(0)=buffer1.getLong(0)+buffer2.getLong(0)
    buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
  }
  
  //计算最终结果
  override def evaluate(buffer: Row): Any = {
    //将sum/count的得到平均数
    buffer.getLong(0).toDouble/buffer.getLong(1)
  }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值