核心要义:
聚合是分步骤进行: 先局部聚合,再全局聚合
局部聚合(reduce)的结果是保存在一个局部buffer中的
全局聚合(merge)就是将多个局部buffer再聚合成一个buffer
最后通过(finish)将全局聚合的buffer中的数据做一个运算得出你要的结果
自定义avg
object Demo01_UDAF {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder()
.master("local[*]")
.appName("自定义函数_平均值")
.getOrCreate()
//导入隐式和函数
import spark.implicits._
import org.apache.spark.sql.functions._
//加载数据
val df: DataFrame = spark.read.option("header", true).option("inferSchema", true).csv("sql_data/csv/a.csv")
/*
|-- id: integer (nullable = true)
|-- name: string (nullable = true)
|-- age: integer (nullable = true)
|-- gender: string (nullable = true)
|-- city: string (nullable = true)
*/
df.printSchema()
df.createTempView("temp")
//注册函数
spark.udf.register("my_avg",udaf(new MyAvgFunction))
spark.sql(
"""
|select
| gender,
| avg(age),
| my_avg(age)
|from
| temp
|group by
| gender
|""".stripMargin).show()
}
}
/**
* Aggregator[-IN, BUF, OUT]
* 第一个泛型 输入参数的数据类型
* 第二个泛型 中间缓存的数据类型
* 第三个泛型 输出结果的数据类型
*/
class MyAvgFunction extends Aggregator[Int,(Int,Int),Double] {
//缓存的初始值(默认值)
override def zero: (Int, Int) = (0,0)
//局部聚合的计算逻辑
override def reduce(b: (Int, Int), a: Int): (Int, Int) = (b._1+a,b._2+1)
//全局聚合的计算逻辑
override def merge(b1: (Int, Int), b2: (Int, Int)): (Int, Int) = (b1._1+b2._1,b1._2+b2._2)
//求最终结果的计算逻辑
override def finish(reduction: (Int, Int)): Double = reduction._1/reduction._2
//缓存的类型 结构
override def bufferEncoder: Encoder[(Int, Int)] = Encoders.tuple(Encoders.scalaInt,Encoders.scalaInt)
//输出的类型 结构
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
自定义max
object Demo02_UDAF {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder()
.appName("自定义聚合函数")
.master("local[*]")
.getOrCreate()
val df: DataFrame = spark.read.option("header", true).option("inferSchema", true).csv("sql_data/csv/a.csv")
df.show()
df.createTempView("temp")
//注册函数
import org.apache.spark.sql.functions._
spark.udf.register("myMax",udaf(new MyMax))
spark.sql(
"""
|select
|gender,
|max(age),
|myMax(age)
|from
|temp
|group by
|gender
|""".stripMargin).show()
}
}
class MyMax extends Aggregator[Int,Int,Int] {
//缓存的初始值
//def zero: BUF
override def zero: Int = 0
//局部计算逻辑
//def reduce(b: BUF, a: IN): BUF
override def reduce(b: Int, a: Int): Int = if (b<a) a else b
//全局计算逻辑
//def merge(b1: BUF, b2: BUF): BUF
override def merge(b1: Int, b2: Int): Int = if (b1<b2) b2 else b1
//最后结果的计算逻辑
//def finish(reduction: BUF): OUT
override def finish(reduction: Int): Int = reduction
//缓存的数据类型
override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
//最终输出的数据类型
override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}