// Spark UDAF example (custom average aggregate function)
package spark.day03
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/** Spark UDAF demo: registers a custom average aggregate (`myavg`) and uses
  * it in SQL to compute the average salary per department from `sql/emp.json`.
  *
  * NOTE(review): `UserDefinedAggregateFunction` is deprecated since Spark 3.0
  * in favor of `org.apache.spark.sql.expressions.Aggregator` — consider
  * migrating if the project is on Spark 3.x.
  */
object _06TestUDAF {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("udaf")
      .getOrCreate()
    val df: DataFrame = spark.read.json("sql/emp.json")
    df.cache()
    df.createTempView("emp")
    // Expose the custom aggregate to SQL under the name "myavg".
    spark.udf.register("myavg", new MyUDAF)
    val sql1 =
      """
        |select deptno,myavg(sal)
        |from emp
        |group by deptno
        |""".stripMargin
    spark.sql(sql1).show()
    // FIX: release the SparkSession — the original never stopped it.
    spark.stop()
  }

  /** Custom average UDAF. Buffer layout: (running sum: Double, row count: Long). */
  class MyUDAF extends UserDefinedAggregateFunction {
    // Input: a single Double column (the salary being averaged).
    override def inputSchema: StructType = StructType {
      Array(
        StructField("sal", DoubleType)
      )
    }
    // Intermediate aggregation state: running sum and element count.
    override def bufferSchema: StructType = StructType {
      Array(
        StructField("sum", DoubleType),
        StructField("count", LongType)
      )
    }
    override def dataType: DataType = DoubleType
    // Same inputs always produce the same result, so Spark may re-evaluate freely.
    override def deterministic: Boolean = true
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0D
      buffer(1) = 0L
    }
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      // FIX: skip SQL NULL salaries. The original called input.getDouble(0)
      // unconditionally, which does not tolerate a NULL cell (and counting
      // NULL rows would skew the average — SQL AVG ignores NULLs).
      if (!input.isNullAt(0)) {
        buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
        buffer.update(1, buffer.getLong(1) + 1)
      }
    }
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      // Combine partial aggregates from different partitions.
      buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }
    override def evaluate(buffer: Row): Any = {
      // FIX: guard the empty buffer. The original computed 0.0 / 0 == NaN
      // when no non-null values were seen; SQL AVG yields NULL instead.
      val n = buffer.getLong(1)
      if (n == 0L) null else buffer.getDouble(0) / n
    }
  }
}