// 大数据系列篇 - Spark SQL 用户定义聚合函数 (Big Data series: Spark SQL user-defined aggregate functions)
package com.test
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
import org.apache.spark.sql._
/**
 * User-defined aggregate functions (UDAFs) in Spark SQL, shown three ways. Each one averages
 * the `age` column of a JSON user data set:
 *
 *  1. [[SparkSqlUdaf.UdafAvg1]] uses the untyped, position-based `UserDefinedAggregateFunction`
 *     API. This API is deprecated since Spark 3.0 in favour of `Aggregator` + `functions.udaf`.
 *  2. [[SparkSqlUdaf.UdafAvg2]] is a typed `Aggregator` over Long values. It is registered for
 *     SQL via `functions.udaf`.
 *  3. [[SparkSqlUdaf.UdafAvg3]] is a typed `Aggregator` over whole [[SparkSqlUdaf.User]] rows.
 *     It is used through the Dataset DSL with `toColumn`, the way typed UDAFs were used before
 *     `functions.udaf` existed.
 *
 * All three declare a Long result, so the average they return is truncated (integer division).
 */
object SparkSqlUdaf {

  /** Input path used when none is given on the command line. */
  private val DefaultInputPath = "data/user.json"

  /**
   * Runs each UDAF variant against the user data set and prints the results.
   *
   * @param args optional; `args(0)` overrides the input JSON path (default `data/user.json`)
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("练习SparkSqlUdaf").setMaster("local[*]")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    // Close the session even if a query throws, so the local Spark context is not left running.
    try {
      val df = spark.read.json(args.headOption.getOrElse(DefaultInputPath))
      df.createOrReplaceTempView("user")

      // Untyped (weakly typed) UDAF.
      spark.udf.register("avgAge1", new UdafAvg1)
      spark.sql("SELECT avgAge1(age) as avgAge1 FROM user").show()

      // Typed Aggregator, wrapped for SQL use with functions.udaf.
      spark.udf.register("avgAge2", functions.udaf(new UdafAvg2))
      spark.sql("SELECT avgAge2(age) as avgAge2 FROM user").show()

      // DSL style used by older Spark versions: the Aggregator becomes a typed column.
      // spark.implicits._ provides the Encoder[User] that `as[User]` needs (and the $"col" syntax).
      import spark.implicits._
      val ds = df.as[User]
      val column = new UdafAvg3().toColumn
      ds.select(column).show()
    } finally {
      spark.close()
    }
  }

  /**
   * Untyped average of one Long column. Buffer slots and input fields are addressed by position.
   *
   * Buffer layout: index 0 holds the running total; index 1 holds the number of non-NULL inputs.
   * NULL inputs are ignored. A group with no non-NULL input evaluates to NULL, which matches
   * the built-in SQL `avg`. Non-empty results are truncated to a Long.
   *
   * Note: `UserDefinedAggregateFunction` is deprecated as of Spark 3.0. Prefer an `Aggregator`
   * registered with `functions.udaf` (see [[UdafAvg2]]).
   */
  class UdafAvg1 extends UserDefinedAggregateFunction {

    /** One Long input column. */
    override def inputSchema: StructType =
      StructType(Array(StructField("age", LongType)))

    /** Aggregation buffer: (total, count). */
    override def bufferSchema: StructType =
      StructType(Array(
        StructField("total", LongType),
        StructField("count", LongType)
      ))

    /** Type of the value returned by `evaluate`. */
    override def dataType: DataType = LongType

    /** The same input always produces the same output. */
    override def deterministic: Boolean = true

    /** Resets both buffer slots to zero. */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L // equivalently: buffer.update(0, 0L)
      buffer(1) = 0L
    }

    /** Adds one input row. NULL ages are skipped, because Row.getLong throws on NULL. */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
      if (!input.isNullAt(0)) {
        buffer.update(0, buffer.getLong(0) + input.getLong(0))
        buffer.update(1, buffer.getLong(1) + 1)
      }

    /** Merges a partial buffer from another partition into `buffer1`. */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }

    /** Returns the truncated average, or NULL when no non-NULL input was seen (no division by zero). */
    override def evaluate(buffer: Row): Any = {
      val count = buffer.getLong(1)
      if (count == 0L) null else buffer.getLong(0) / count
    }
  }

  /**
   * Truncated mean shared by the typed aggregators.
   *
   * Returns 0 for an empty buffer instead of throwing ArithmeticException, because their
   * primitive Long result cannot represent NULL.
   */
  private def truncatedMean(buf: UserBuf): Long =
    if (buf.count == 0L) 0L else buf.total / buf.count

  /**
   * Typed average of Long values, exposed to SQL through `functions.udaf`.
   * Type parameters: input Long, buffer [[UserBuf]], output Long (truncated average).
   */
  class UdafAvg2 extends Aggregator[Long, UserBuf, Long] {

    /** Empty buffer; merging it with any buffer leaves that buffer unchanged. */
    override def zero: UserBuf = UserBuf(0L, 0L)

    /** Folds one input value into `b`. The buffer is updated in place and returned. */
    override def reduce(b: UserBuf, a: Long): UserBuf = {
      b.count += 1
      b.total += a
      b
    }

    /** Folds partial buffer `b2` into `b1`. `b1` is updated in place and returned. */
    override def merge(b1: UserBuf, b2: UserBuf): UserBuf = {
      b1.count += b2.count
      b1.total += b2.total
      b1
    }

    /** Returns the truncated average, or 0 when nothing was aggregated. */
    override def finish(reduction: UserBuf): Long = truncatedMean(reduction)

    /** The buffer is a case class, so its encoder comes from Encoders.product. */
    override def bufferEncoder: Encoder[UserBuf] = Encoders.product

    /** Encoder for the primitive Long result. */
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }

  /**
   * Mutable aggregation buffer shared by [[UdafAvg2]] and [[UdafAvg3]].
   * The fields are `var`s so `reduce` and `merge` can update one buffer in place
   * instead of allocating a new one per row.
   */
  case class UserBuf(var total: Long, var count: Long)

  /**
   * Typed average of `User.age` over whole [[User]] rows, used via `toColumn` in the Dataset DSL.
   * Type parameters: input [[User]], buffer [[UserBuf]], output Long (truncated average).
   */
  class UdafAvg3 extends Aggregator[User, UserBuf, Long] {

    /** Empty buffer; merging it with any buffer leaves that buffer unchanged. */
    override def zero: UserBuf = UserBuf(0L, 0L)

    /** Folds one user's age into `b`. The buffer is updated in place and returned. */
    override def reduce(b: UserBuf, a: User): UserBuf = {
      b.count += 1
      b.total += a.age
      b
    }

    /** Folds partial buffer `b2` into `b1`. `b1` is updated in place and returned. */
    override def merge(b1: UserBuf, b2: UserBuf): UserBuf = {
      b1.count += b2.count
      b1.total += b2.total
      b1
    }

    /** Returns the truncated average, or 0 when nothing was aggregated. */
    override def finish(reduction: UserBuf): Long = truncatedMean(reduction)

    /** The buffer is a case class, so its encoder comes from Encoders.product. */
    override def bufferEncoder: Encoder[UserBuf] = Encoders.product

    /** Encoder for the primitive Long result. */
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }

  /**
   * One record of the input JSON. `as[User]` maps columns to these fields by name.
   * NOTE(review): `age` is a primitive Long. A row with a missing or NULL age will
   * presumably fail to decode into this class; confirm the data has no NULL ages.
   */
  case class User(var userName: String, var age: Long)
}