Spark-SQL UDAF函数(scala)

Aggregator源码

//类型参数: IN – 聚合的输入类型。 
//BUF – 聚合的中间值的类型。 
//OUT – 最终输出结果的类型

abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
  
  //此聚合的零值。应该满足任意 b + 零 = b 的性质
  def zero: BUF
  
  //组个两个值产生一个新的值,为了性能,该函数可能会修改b并将其返回,而不是为 b 构造新对象
  def reduce(b: BUF, a: IN): BUF
  
  //合并两个中间值
  def merge(b1: BUF, b2: BUF): BUF
  
  //变换归约的输出
  def finish(reduction: BUF): OUT
  
  //指定中间值类型的编码器
  def bufferEncoder: Encoder[BUF]
  
  //指定最终输出值类型的编码器。
  def outputEncoder: Encoder[OUT]

编码器:
自定义类型 Case Class 或者元组就使用 Encoders.product 方法;
基本类型就使用其对应名称的方法,如 scalaByte,scalaFloat,scalaShort 等

示例如下:
override def bufferEncoder: Encoder[SumAndCount] = Encoders.product
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

在这里插入图片描述

1.有类型

import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator

/***
 * @Author: lzx
 * @Description: 自定义UDAF函数 【有类型】
 * @Date: 2022/12/27
 * UserDefinedAggregateFunction已经在Spark 3.0以上版本过时,建议使用 Aggregator[IN, BUF, OUT]
 * deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
  " via the functions.udaf(agg) method.", "3.0.0")
 **/

object UserDefinedTypedAggregation {
  //输入类型
  case class Employee(name:String,salary:Long)
  //中间聚合类型
  case class Average(var sum:Long,var count:Long)

  object MyAverage2 extends Aggregator[Employee,Average,Double]{

    //1.用于聚合操作的的初始零值
    override def zero: Average = Average(0L,0L)

    //2.分区内聚合
    override def reduce(b: Average, a: Employee): Average = {
      b.sum=b.sum+a.salary
      b.count=b.count+1
      b
    }

    //3.不同分区聚合
    override def merge(b1: Average, b2: Average): Average = {
      b1.sum=b1.sum+b2.sum
      b1.count=b1.count+b2.count
      b1
    }

    //4.输出结果
    override def finish(reduction: Average): Double = {
      reduction.sum.toDouble / reduction.count
    }

    //5.中间类型的编码转换
    override def bufferEncoder: Encoder[Average] = Encoders.product

    //6.输出类型的编码转换
    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }

  def main(args: Array[String]): Unit = {
    val session: SparkSession = SparkSession.builder().master("local[12]").appName("").getOrCreate()

    import session.implicits._
    val ds: Dataset[Employee] = session.read.json("file:///C:\\WorkPlace\\spark-lzx\\spark-sql\\src\\input\\employee.json")
      .as[Employee]

    ds.show(false)
    val averageSalary: TypedColumn[Employee, Double] = MyAverage2.toColumn.name("average_salary")
    ds.select(averageSalary)
      .show(false)

    session.close()

  }
}

2.无类型

import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

object UserDefinedUnTypedAggregation{


  //todo 定义聚合操作的中间输出类型
  case class Average(var sum:Long, var count:Long)

  /**
              参数:IN – 输入类型
              BUF – reduce 中间结果类型
              OUT – 输出类型
   **/
  object MyAverage1 extends Aggregator[Long,Average,Double] {

    //用于聚合操作的的初始零值
    override def zero: Average = Average(0,0)

    //同一分区中的 reduce 操作
    override def reduce(b: Average, a: Long): Average ={
      b.sum += a
      b.count +=1
      b
    }

    //不同分区中的 merge 操作
    override def merge(b1: Average, b2: Average): Average = {
      b1.sum += b2.sum
      b1.count +=b2.count
      b1
    }

    //定义最终的输出类型
    override def finish(reduction: Average): Double = {
      reduction.sum.toDouble / reduction.count
    }

    //中间类型的编码转换
    override def bufferEncoder: Encoder[Average] = Encoders.product

    //输出类型的编码转换
    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }

  def main(args: Array[String]): Unit = {
    val session: SparkSession = SparkSession.builder().master("local[12]").appName("").getOrCreate()

    session.udf.register("myAverage",functions.udaf(MyAverage1))

    val sc: SparkContext = session.sparkContext
    import session.implicits._
    val df: DataFrame = sc.makeRDD(Seq(("xiaoming", 1000), ("zhaosi", 2000))).toDF("name", "salary")

    df.show(false)

    df.createOrReplaceTempView("tmp")

    session.sql("select myAverage(salary)as average from tmp").show(false)

    session.close()

  }

}

参考 Spark官网

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值