Spark SQL -- UDAF Functions

Requirement: implement a custom aggregate function by extending UserDefinedAggregateFunction. Example: compute the average salary of employees.
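
Both examples below read the same employees.json file. For reference, spark.read.json expects one JSON object per line; a minimal input file might look like the following (the records here follow the employees.json sample that ships with Spark's examples, and the author's actual file may differ):

{"name": "Michael", "salary": 3000}
{"name": "Andy", "salary": 4500}
{"name": "Justin", "salary": 3500}
{"name": "Berta", "salary": 4000}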

Weakly typed aggregate function:

package com.jiangnan.spark
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
  * Weakly typed.
  * Computes the average salary of employees.
  */
class AverageSalaryRuo extends UserDefinedAggregateFunction{
  // Schema of the input data
  override def inputSchema: StructType = StructType(StructField("salary",IntegerType) :: Nil)
  // Schema of the aggregation buffer shared within each partition
  override def bufferSchema: StructType = StructType(StructField("sum",LongType) :: StructField("count",IntegerType):: Nil)
  // Data type of the output
  override def dataType: DataType = DoubleType
  // Whether the same input always produces the same output; true here
  override def deterministic: Boolean = true
  // Initialize the per-partition shared buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0
  }
  // Update the buffer for each input row within a partition
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Add the newly input salary to the running sum in the buffer; the buffer behaves much like a ResultSet
    buffer(0) = buffer.getLong(0) + input.getInt(0)
    // For each salary added, increment the employee count by 1
    buffer(1) = buffer.getInt(1)+1
  }
  // Merge the buffers produced by the individual partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getInt(1)+buffer2.getInt(1)
  }
  // Produce the final result
  override def evaluate(buffer: Row): Any = {
    // Compute and return the average salary
    buffer.getLong(0).toDouble/buffer.getInt(1)
  }
}
object AverageSalaryRuo extends App{
  val conf = new SparkConf().setAppName("udaf").setMaster("local[3]")
  val spark = SparkSession.builder().config(conf).getOrCreate()
  val data = spark.read.json("C:\\Users\\zhang\\Desktop\\employees.json")
  data.createOrReplaceTempView("employee")
  // Register the custom aggregate function
  spark.udf.register("avgSalary",new AverageSalaryRuo)
  spark.sql("select avgSalary(salary) from employee").show()
  spark.stop()
}
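
Once registered, the UDAF is not limited to SQL; it can also be invoked through the DataFrame API. A minimal sketch, assuming the same data and registration as above:

import org.apache.spark.sql.functions.{callUDF, col}

// Equivalent to the SQL query above: apply the registered UDAF to the salary column
data.agg(callUDF("avgSalary", col("salary"))).show()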

Strongly typed aggregate function:

package com.jiangnan.spark
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
/**
  * Strongly typed.
  * Computes the average salary of employees.
  */
// The strongly typed version is essentially built on case classes
case class Employee(name:String,salary:Long)
case class Average(var sum:Long,var count:Int)
class AverageSalaryQiang extends Aggregator[Employee,Average,Double]{
  // Initial (zero) value of the aggregation buffer
  override def zero: Average = Average(0L,0)
  // Called for each element within a partition, analogous to update
  override def reduce(b: Average, a: Employee): Average = {
    b.sum = b.sum + a.salary
    b.count = b.count + 1
    b
  }
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }
  // Compute the final result
  override def finish(reduction: Average): Double = {
    reduction.sum.toDouble /reduction.count
  }
  // Encoder for the buffer type
  override def bufferEncoder: Encoder[Average] = Encoders.product
  // Encoder for the output type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
object AverageSalaryQiang extends App{
  val conf = new SparkConf().setAppName("udaf").setMaster("local[3]")
  val spark = SparkSession.builder().config(conf).getOrCreate()
  import spark.implicits._
  val employee = spark.read.json("C:\\Users\\zhang\\Desktop\\employees.json").as[Employee]
  employee.show()
  employee.createOrReplaceTempView("employee")
  // Convert the aggregator into a named TypedColumn for use with the Dataset API
  val averageSalary = new AverageSalaryQiang().toColumn.name("average_salary")
  spark.sql("select * from employee").show()
  // An Aggregator converted with toColumn cannot be called by name from SQL
  // (see the Spark 3.0+ note after this listing):
  //spark.sql("select average_salary(salary) from employee").show()
  employee.select(averageSalary).show()
  spark.stop()
}
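
Note: the SQL line above is left disabled because, through Spark 2.x, an Aggregator turned into a Column with toColumn can only be used via the typed Dataset API, not called by name from SQL. Since Spark 3.0, where UserDefinedAggregateFunction is deprecated, an Aggregator can be registered for SQL use through functions.udaf. A sketch under that assumption: the input type must then be the column type (Long) rather than Employee, and the names SalaryAverage and avgSalaryTyped are made up for illustration:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.{Encoder, Encoders}

// Aggregates a plain Long salary column rather than a whole Employee row,
// so it can be wrapped as an untyped UDAF and called from SQL (Spark 3.0+)
object SalaryAverage extends Aggregator[Long, Average, Double] {
  override def zero: Average = Average(0L, 0)
  override def reduce(b: Average, salary: Long): Average = {
    b.sum += salary; b.count += 1; b
  }
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum; b1.count += b2.count; b1
  }
  override def finish(r: Average): Double = r.sum.toDouble / r.count
  override def bufferEncoder: Encoder[Average] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register under a SQL-callable name and use it like any built-in aggregate
spark.udf.register("avgSalaryTyped", udaf(SalaryAverage, Encoders.scalaLong))
spark.sql("select avgSalaryTyped(salary) from employee").show()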
