import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}
object UserDefineFunction3 extends App {
  private val session: SparkSession = SparkSession.builder().appName("test").master("local").getOrCreate()

  // Define a sample list of scores
  val list = List(
    Score(1, "zs", 99.9),
    Score(2, "ls", 89.2),
    Score(3, "xm", 29.3),
    Score(4, "wx", 79.0),
    Score(5, "lq", 69.9),
    Score(6, "zl", 49.5)
  )

  // session is the SparkSession object; its implicits enable toDS()
  import session.implicits._
  // Convert the List to a Dataset
  val ds = list.toDS()
  // Create a temporary view
  ds.createTempView("tmp")
  // Register the custom aggregate function
  session.udf.register("myavg", new MyAvgUDAF)
  // Define the SQL statement: compare the built-in avg with our myavg
  val sql =
    """
      |select avg(score),
      |myavg(score),
      |name
      |from tmp
      |group by name
      |""".stripMargin
  // Run the SQL and print the result
  session.sql(sql).show()
  session.stop()
}
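Note that every name in the sample list is unique, so each group produced by group by name contains exactly one row, and both avg(score) and myavg(score) simply return that row's score. To see the aggregation actually combine rows, add a second row with an existing name (for example, a hypothetical Score(7, "zs", 59.9)) before re-running the query.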
case class Score(id: Int, name: String, score: Double)
// Custom UDAF class: extend UserDefinedAggregateFunction
class MyAvgUDAF extends UserDefinedAggregateFunction {
  // Schema of the input column(s)
  override def inputSchema: StructType = {
    StructType(Array(StructField("score", DataTypes.DoubleType)))
  }

  // Schema of the intermediate aggregation buffer
  override def bufferSchema: StructType = {
    StructType(Array(StructField("sum", DataTypes.DoubleType),
      StructField("count", DataTypes.IntegerType)))
  }

  // Data type of the final result
  override def dataType: DataType = DataTypes.DoubleType

  // The function returns the same output for the same input
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Intermediate state lives in the buffer; ours has two slots: sum and count
    // Initialize the first slot (sum); 0d is a Double literal
    buffer.update(0, 0d)
    // Initialize the second slot (count) to 0
    buffer.update(1, 0)
  }

  // Accumulate within a partition
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val score: Double = input.getAs(0)
    // Add the incoming row's score to the running sum
    buffer.update(0, buffer.getDouble(0) + score)
    // Increment the row count
    buffer.update(1, buffer.getInt(1) + 1)
  }

  // Merge buffers across partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1))
  }

  // Compute the final output: average = sum / count
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0) / buffer.getInt(1)
  }
}
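UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of the strongly typed Aggregator combined with functions.udaf. Below is a minimal sketch of the same average under that API, assuming Spark 3.x; the names MyAvgAggregator and AvgBuffer are illustrative, not part of the original code.

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// Illustrative buffer type holding the running sum and count
case class AvgBuffer(sum: Double, count: Long)

object MyAvgAggregator extends Aggregator[Double, AvgBuffer, Double] {
  // Initial (zero) buffer
  override def zero: AvgBuffer = AvgBuffer(0.0, 0L)
  // Fold one input value into the buffer (within a partition)
  override def reduce(b: AvgBuffer, score: Double): AvgBuffer =
    AvgBuffer(b.sum + score, b.count + 1)
  // Combine buffers from different partitions
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer =
    AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)
  // Produce the final result
  override def finish(b: AvgBuffer): Double = b.sum / b.count
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Registration is analogous to the original:
// session.udf.register("myavg", functions.udaf(MyAvgAggregator))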