Spark UDAF: User-Defined Aggregate Functions


A UDAF is an N:1 function: it consumes N input rows and produces a single output value, which is exactly what aggregation (group by) needs.
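The goal is to reproduce what a built-in aggregate such as avg already does, but with our own logic. As a point of reference (this assumes the emp view that is registered later in this post):

    // built-in equivalent of the custom average written below
    spark.sql("select sex, avg(age) as avg_age from emp group by sex").show()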

Processing flow

First, prepare the data source (the figures from the original post are omitted here).

Here we deliberately split the data into 2 partitions.

Within each partition, rows are grouped by the group by field.

Inside every group of every partition, the target field (age in this example) is processed row by row by executing the update method.

Rows that share the same group by key then form one group and are pulled together across partitions (the shuffle).

Finally, the merge operation is executed for each group produced by group by, as the worked example below illustrates.
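Take the 8 sample rows used later in this post, parallelized into 2 partitions (with parallelize(list, 2) the first four rows normally land in partition 0 and the last four in partition 1). For the group man:

partition 0 sees the ages 20 and 12, so update leaves its buffer at (sum = 32, count = 2);
partition 1 sees the ages 90 and 100, so its buffer becomes (sum = 190, count = 2);
merge combines the two buffers into (sum = 222, count = 4);
the final evaluation returns 222.0 / 4 = 55.5.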

Weakly typed

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * Weakly typed custom UDAF
  */
class MyAvgUDAF extends UserDefinedAggregateFunction {
  /**
    * Type of the input column: age: Int
    */
  override def inputSchema: StructType = {
    val fields: Array[StructField] = Array(
      // the field name is only a label; any name works here
      StructField("input name", IntegerType)
    )
    StructType(fields)
  }

  /**
    * Types of the intermediate buffer variables: sum and count
    */
  override def bufferSchema: StructType = {
    val fields: Array[StructField] = Array(
      StructField("sum", LongType),
      StructField("count", LongType)
    )
    StructType(fields)
  }

  /**
    * Type of the return value
    */
  override def dataType: DataType = DoubleType

  /**
    * Deterministic: does the same input always produce the same result?
    */
  override def deterministic: Boolean = true

  /**
    * Initial values of the buffer variables: sum = 0, count = 0
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // initialize sum to 0
    buffer(0) = 0L
    // initialize count to 0
    buffer(1) = 0L
  }

  /**
    * Accumulate one input row into the buffer, similar to a Combiner; runs inside each task
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // current sum
    val sum: Long = buffer.getAs[Long](0)
    // current count
    val count: Long = buffer.getAs[Long](1)

    // take the target value: age
    val age: Int = input.getAs[Int](0)

    // update sum
    buffer.update(0, sum + age)
    // update count
    buffer.update(1, count + 1)
  }

  /**
    * Merge the sum and count of this group across all tasks, similar to a Reducer
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // sum accumulated so far
    val sum: Long = buffer1.getAs[Long](0)
    // count accumulated so far
    val count: Long = buffer1.getAs[Long](1)

    // sum from the other task
    val taskSum: Long = buffer2.getAs[Long](0)
    // count from the other task
    val taskCount: Long = buffer2.getAs[Long](1)

    // merge the two buffers
    buffer1.update(0, sum + taskSum)
    buffer1.update(1, count + taskCount)
  }

  /**
    * Compute the final result
    */
  override def evaluate(buffer: Row): Any = {
    // final sum
    val sum: Long = buffer.getAs[Long](0)
    // final count
    val count: Long = buffer.getAs[Long](1)

    sum.toDouble / count
  }
}

Invocation:

    val list = List(
      ("lisi", 20, "man"),
      ("wangwu", 30, "woman"),
      ("zhaoliu", 12, "man"),
      ("lilei", 40, "woman"),
      ("hanmeimei11", 30, "woman"),
      ("hanmeimei22", 80, "woman"),
      ("hanmeimei33", 90, "man"),
      ("hanmeimei44", 100, "man"))
    val rdd: RDD[(String, Int, String)] = spark.sparkContext.parallelize(list, 2)
    // print which partition each row lands in
    rdd.mapPartitionsWithIndex((index, iter) => {
      println(s"index:${index}    data:${iter.toList}")
      iter
    }).collect()
    import spark.implicits._
    val df: DataFrame = rdd.toDF("name", "age", "sex")
    // register as a temporary view
    df.createOrReplaceTempView("emp")
    // create the custom UDAF
    val myavg = new MyAvgUDAF
    // register it
    spark.udf.register("myavg", myavg)
    // run the query
    spark.sql(
      """
        |select  sex,myavg(age) avg_age
        |from emp
        |group by sex
      """.stripMargin).show()

Strongly typed

The strongly typed version is more concise and convenient; since Spark 3.0 it is also the recommended approach (UserDefinedAggregateFunction is deprecated in favor of Aggregator).

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

/**
  * Strongly typed custom UDAF
  */

// case class used as the intermediate buffer: sum, count
case class AvgAgeBuff(var sum: Long, var count: Long)

/**
  * Aggregator[IN, BUFF, OUT]
  *   IN:   type of the function's input
  *   BUFF: type of the intermediate buffer
  *   OUT:  type of the final result
  */
class MyAvgAggregator extends Aggregator[Int, AvgAgeBuff, Double] {
  // initial (zero) value of the buffer
  override def zero: AvgAgeBuff = AvgAgeBuff(0L, 0L)

  /**
    * Fold one input value into the buffer, per group within each partition
    */
  override def reduce(b: AvgAgeBuff, age: Int): AvgAgeBuff = {
    // previous sum + current age
    // previous count + 1
    AvgAgeBuff(b.sum + age, b.count + 1)
  }

  /**
    * Merge the buffers of the same group coming from different partitions
    */
  override def merge(b1: AvgAgeBuff, b2: AvgAgeBuff): AvgAgeBuff = {
    AvgAgeBuff(b1.sum + b2.sum, b1.count + b2.count)
  }

  /**
    * Compute the final result
    */
  override def finish(reduction: AvgAgeBuff): Double = reduction.sum.toDouble / reduction.count

  /**
    * Encoder for the intermediate buffer
    */
  override def bufferEncoder: Encoder[AvgAgeBuff] = Encoders.product[AvgAgeBuff]

  /**
    * Encoder for the final result
    */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

Invocation:

    val list = List(
      ("lisi", 20, "man"),
      ("wangwu", 30, "woman"),
      ("zhaoliu", 12, "man"),
      ("lilei", 40, "woman"),
      ("hanmeimei11", 30, "woman"),
      ("hanmeimei22", 80, "woman"),
      ("hanmeimei33", 90, "man"),
      ("hanmeimei44", 100, "man"))
    val rdd: RDD[(String, Int, String)] = spark.sparkContext.parallelize(list, 2)
    // print which partition each row lands in
    rdd.mapPartitionsWithIndex((index, iter) => {
      println(s"index:${index}    data:${iter.toList}")
      iter
    }).collect()
    import spark.implicits._
    val df: DataFrame = rdd.toDF("name", "age", "sex")
    // register as a temporary view
    df.createOrReplaceTempView("emp")
    // create the custom aggregator
    val aggregator = new MyAvgAggregator
    import org.apache.spark.sql.expressions.UserDefinedFunction
    import org.apache.spark.sql.functions._
    // wrap the Aggregator as an untyped UDAF
    val func: UserDefinedFunction = udaf(aggregator)
    // register it
    spark.udf.register("myavg", func)
    spark.sql(
      """
        |select  sex,myavg(age) avg_age
        |from emp
        |group by sex
      """.stripMargin).show()