需求:需要通过继承 UserDefinedAggregateFunction 来实现自定义聚合函数。案例:计算一下员工的平均工资
弱类型聚合函数:
package com.jiangnan.spark
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
 * Weakly-typed user-defined aggregate function (UDAF).
 * Computes the average salary of employees.
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0
 * in favour of Aggregator + functions.udaf; kept here as the weakly-typed example.
 */
class AverageSalaryRuo extends UserDefinedAggregateFunction{
  // Schema of the input rows: a single integer "salary" column.
  override def inputSchema: StructType = StructType(StructField("salary",IntegerType) :: Nil)
  // Schema of the per-partition aggregation buffer:
  // running sum of salaries and running count of rows.
  override def bufferSchema: StructType = StructType(StructField("sum",LongType) :: StructField("count",IntegerType):: Nil)
  // Type of the final result.
  override def dataType: DataType = DoubleType
  // The same input always produces the same output.
  override def deterministic: Boolean = true
  // Initialise the shared per-partition buffer.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // sum
    buffer(1) = 0  // count
  }
  // Fold one input row into the buffer.
  // Fix: skip null salaries — previously input.getInt(0) on a null value
  // would throw at runtime; Spark's built-in avg also ignores nulls.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      // Add the new salary to the running total; the buffer works like a ResultSet.
      buffer(0) = buffer.getLong(0) + input.getInt(0)
      // One more employee counted.
      buffer(1) = buffer.getInt(1) + 1
    }
  }
  // Merge the buffers of two partitions.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }
  // Produce the final result.
  // Fix: return SQL NULL instead of NaN when no rows were aggregated,
  // matching the behaviour of Spark's built-in avg.
  override def evaluate(buffer: Row): Any = {
    if (buffer.getInt(1) == 0) null
    else buffer.getLong(0).toDouble / buffer.getInt(1)
  }
}
/**
 * Entry point demonstrating the weakly-typed UDAF.
 *
 * Fix: use an explicit main method instead of extending App — the App
 * trait's delayed initialisation is a well-known pitfall for non-trivial
 * entry points.
 */
object AverageSalaryRuo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("udaf").setMaster("local[3]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    // NOTE(review): hard-coded Windows path; consider passing it via args.
    val data = spark.read.json("C:\\Users\\zhang\\Desktop\\employees.json")
    data.createOrReplaceTempView("employee")
    // Register the custom aggregate function so SQL can call it.
    spark.udf.register("avgSalary", new AverageSalaryRuo)
    spark.sql("select avgSalary(salary) from employee").show()
    spark.stop()
  }
}
强类型聚合函数:
package com.jiangnan.spark
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
/**
 * Strongly-typed aggregation (the original comment said "weakly-typed" —
 * that was a copy-paste error; this is the typed Aggregator variant).
 * Computes the average salary of employees.
 */
//The strongly-typed approach is built on case classes:
//Employee is the input record, Average is the mutable aggregation buffer.
case class Employee(name:String,salary:Long)
case class Average(var sum:Long,var count:Int)
/**
 * Strongly-typed aggregator computing the average employee salary.
 * Input: Employee, buffer: Average, output: Double.
 */
class AverageSalaryQiang extends Aggregator[Employee,Average,Double]{
  /** Zero value of the aggregation buffer. */
  override def zero: Average = Average(sum = 0L, count = 0)

  /** Fold one employee into the buffer — the per-partition step. */
  override def reduce(b: Average, a: Employee): Average = {
    b.sum += a.salary
    b.count += 1
    b
  }

  /** Combine the buffers of two partitions. */
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  /** Turn the final buffer into the average salary. */
  override def finish(reduction: Average): Double =
    reduction.sum.toDouble / reduction.count

  /** Encoder for the intermediate Average buffer. */
  override def bufferEncoder: Encoder[Average] = Encoders.product[Average]

  /** Encoder for the Double output. */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
/**
 * Entry point demonstrating the strongly-typed aggregator.
 *
 * Fix: use an explicit main method instead of extending App — the App
 * trait's delayed initialisation is a well-known pitfall for non-trivial
 * entry points.
 */
object AverageSalaryQiang {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("udaf").setMaster("local[3]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._
    // NOTE(review): hard-coded Windows path; consider passing it via args.
    val employee = spark.read.json("C:\\Users\\zhang\\Desktop\\employees.json").as[Employee]
    employee.show()
    employee.createOrReplaceTempView("employee")
    // A typed Aggregator is used as a Column on a Dataset; it is NOT
    // registered as a SQL function here (in Spark 3 use functions.udaf
    // for that), which is why the SQL call below stays commented out.
    val avgSalaryCol = new AverageSalaryQiang().toColumn.name("aaaa")
    spark.sql("select * from employee").show()
    //spark.sql("select aaaa(salary) from employee").show()
    employee.select(avgSalaryCol).show()
    spark.stop()
  }
}