UDAF的特点就是:N:1,目的就是为了做聚合(group by)
处理流程
首先准备好数据源:
这里我们人为的将其分为2个分区:
按照group by字段进行分组:
在每一个partition内的每一个分组内,按照目标字段(以age为例)操作,执行update方法
group by字段相同的为一组,拉取数据:
最后分别针对group by分成的几组,执行merge操作合并各分区的中间结果,再通过evaluate计算出最终结果!
弱类型
/**
 * Weakly-typed (untyped) user-defined aggregate function that computes
 * the average of an Int column (e.g. age), as sum/count.
 *
 * Lifecycle: initialize -> update (per row, per partition) ->
 * merge (combine partition buffers) -> evaluate (final result).
 */
class MyAvgUDAF extends UserDefinedAggregateFunction{
  /**
   * Schema of the input column: a single Int (age).
   */
  override def inputSchema: StructType = {
    val fields: Array[StructField] = Array(
      StructField("input name", IntegerType)
    )
    StructType(fields)
  }

  /**
   * Schema of the intermediate aggregation buffer: running sum and count.
   */
  override def bufferSchema: StructType = {
    val fields: Array[StructField] = Array(
      StructField("sum", LongType),
      StructField("count", LongType)
    )
    StructType(fields)
  }

  /**
   * Type of the final result (the average).
   */
  override def dataType: DataType = DoubleType

  /**
   * Deterministic: the same input always yields the same result.
   */
  override def deterministic: Boolean = true

  /**
   * Initialize the buffer: sum = 0, count = 0.
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // sum
    buffer(1) = 0L // count
  }

  /**
   * Fold one input row into the buffer (combiner-like, runs per task).
   * Null inputs are skipped, matching the semantics of the built-in avg().
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Guard against null age values: reading a null Int via getAs would
    // silently yield 0 (or NPE on unboxing) and wrongly bump the count.
    if (!input.isNullAt(0)) {
      val sum: Long = buffer.getAs[Long](0)
      val count: Long = buffer.getAs[Long](1)
      val age: Int = input.getAs[Int](0)
      buffer.update(0, sum + age)
      buffer.update(1, count + 1)
    }
  }

  /**
   * Merge the partial buffers of all tasks for the same group (reducer-like).
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // buffer1 accumulates; buffer2 is the other task's partial result.
    val sum: Long = buffer1.getAs[Long](0)
    val count: Long = buffer1.getAs[Long](1)
    val taskSum: Long = buffer2.getAs[Long](0)
    val taskCount: Long = buffer2.getAs[Long](1)
    buffer1.update(0, sum + taskSum)
    buffer1.update(1, count + taskCount)
  }

  /**
   * Compute the final result from the merged buffer.
   * Returns null (like the built-in avg) when no non-null value was seen,
   * instead of the NaN produced by 0.0 / 0.
   */
  override def evaluate(buffer: Row): Any = {
    val sum: Long = buffer.getAs[Long](0)
    val count: Long = buffer.getAs[Long](1)
    if (count == 0L) null else sum.toDouble / count
  }
}
调用:
// Sample data: (name, age, sex), spread over 2 partitions on purpose.
val list = List(
  ("lisi",20,"man"),
  ("wangwu",30,"woman"),
  ("zhaoliu",12,"man"),
  ("lilei",40,"woman"),
  ("hanmeimei11",30,"woman"),
  ("hanmeimei22",80,"woman"),
  ("hanmeimei33",90,"man"),
  ("hanmeimei44",100,"man"))
val rdd: RDD[(String, Int, String)] = spark.sparkContext.parallelize(list, 2)
// Debug: print each partition's contents (printed on the executors).
// NOTE: iter.toList EXHAUSTS the iterator, so we must return the
// materialized list's iterator — returning the spent `iter` would make
// collect() see empty partitions.
rdd.mapPartitionsWithIndex { (index, iter) =>
  val rows = iter.toList
  println(s"index:${index} data:${rows}")
  rows.iterator
}.collect()
import spark.implicits._
val df: DataFrame = rdd.toDF("name", "age", "sex")
// Register the DataFrame as a temp view so it can be queried with SQL.
df.createOrReplaceTempView("emp")
// Instantiate and register the custom UDAF.
val myavg = new MyAvgUDAF
spark.udf.register("myavg", myavg)
// Average age per sex, via the custom aggregate.
spark.sql(
  """
    |select sex,myavg(age) avg_age
    |from emp
    |group by sex
  """.stripMargin).show()
强类型
强类型更简洁方便
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
/**
 * Strongly-typed custom UDAF (buffer type).
 */
// Case class used as the aggregation buffer (intermediate state):
// running sum of ages and count of rows. Fields are vars so the
// Aggregator may update them in place.
case class AvgAgeBuff(var sum:Long,var count:Long)
/**
 * Strongly-typed average aggregator built on Aggregator[IN, BUF, OUT]:
 *   IN  = Int        — the input value (age)
 *   BUF = AvgAgeBuff — intermediate state (running sum and count)
 *   OUT = Double     — the final result (average)
 */
class MyAvgAggregator extends Aggregator[Int,AvgAgeBuff,Double]{
  /** A fresh, empty buffer: sum = 0, count = 0. */
  override def zero: AvgAgeBuff = AvgAgeBuff(0L, 0L)

  /**
   * Fold one input value into the buffer
   * (runs per group within each partition).
   */
  override def reduce(b: AvgAgeBuff, age: Int): AvgAgeBuff = {
    // Buffer fields are vars, so update in place and hand the buffer back.
    b.sum += age
    b.count += 1
    b
  }

  /**
   * Combine the partial buffers of the same group across partitions.
   */
  override def merge(b1: AvgAgeBuff, b2: AvgAgeBuff): AvgAgeBuff = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  /** Final result: average = sum / count. */
  override def finish(reduction: AvgAgeBuff): Double =
    reduction.sum.toDouble / reduction.count

  /** Encoder for the buffer type (a case class, hence a Product). */
  override def bufferEncoder: Encoder[AvgAgeBuff] = Encoders.product[AvgAgeBuff]

  /** Encoder for the output type. */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
调用:
// Sample data: (name, age, sex), spread over 2 partitions on purpose.
val list = List(
  ("lisi",20,"man"),
  ("wangwu",30,"woman"),
  ("zhaoliu",12,"man"),
  ("lilei",40,"woman"),
  ("hanmeimei11",30,"woman"),
  ("hanmeimei22",80,"woman"),
  ("hanmeimei33",90,"man"),
  ("hanmeimei44",100,"man"))
val rdd: RDD[(String, Int, String)] = spark.sparkContext.parallelize(list, 2)
// Debug: print each partition's contents (printed on the executors).
// NOTE: iter.toList EXHAUSTS the iterator, so return the materialized
// list's iterator — returning the spent `iter` would make collect()
// see empty partitions.
rdd.mapPartitionsWithIndex { (index, iter) =>
  val rows = iter.toList
  println(s"index:${index} data:${rows}")
  rows.iterator
}.collect()
import spark.implicits._
val df: DataFrame = rdd.toDF("name", "age", "sex")
// Register the DataFrame as a temp view so it can be queried with SQL.
df.createOrReplaceTempView("emp")
// Instantiate the strongly-typed aggregator.
val aggregator = new MyAvgAggregator
import org.apache.spark.sql.functions._
// Wrap the Aggregator as a UserDefinedFunction via functions.udaf.
// Do NOT annotate this as Any: spark.udf.register(name, func) requires a
// UserDefinedFunction, and an Any-typed value would not match that overload.
val func = udaf(aggregator)
spark.udf.register("myavg", func)
spark.sql(
  """
    |select sex,myavg(age) avg_age
    |from emp
    |group by sex
  """.stripMargin).show()