package com._51doit.spark.SQLUDF
import org.apache.spark.sql.SparkSession
object UDAFMyAvg {
  /**
   * Demonstrates a custom average aggregation using the pre-Spark-3.0
   * `UserDefinedAggregateFunction` API.
   *
   * Author: (original: "implement avg")
   * Date: 2020/8/21
   */
  def main(args: Array[String]): Unit = {
    // BUG FIX: the original chained .appName(...).appName("local[*]"),
    // so the second call overwrote the application name with "local[*]".
    // The master URL is set once via .master(...).
    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName)
      .master("local[*]")
      .getOrCreate()

    // Read the CSV with a header row and let Spark infer the column types.
    // NOTE(review): Windows-style path — assumes the data file lives at
    // data\AVG.csv relative to the working directory.
    val df = spark.read
      .option("header", true)
      .option("inferSchema", true)
      .csv("data\\AVG.csv")
    df.createTempView("v_avg")

    // Register the pre-Spark-3.0 UserDefinedAggregateFunction implementation.
    // BUG FIX: the original registered `new MyAvgDemo`, a class that does not
    // exist; the UDAF defined in this codebase is `MyAvgFunction`.
    val myAvg = spark.udf.register("myAvg", new MyAvgFunction)

    spark.sql(
      """
        |SELECT myAvg(salary),dept from v_avg group by dept
        |
        |""".stripMargin).show()
  }
}
package cn._51doit.spark.day12
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructField, StructType}
/**
 * Average aggregation implemented with the pre-Spark-3.0
 * `UserDefinedAggregateFunction` API: accumulates a running sum and a
 * running count, then divides at the end.
 */
class MyAvgFunction extends UserDefinedAggregateFunction {

  // Type of the input column: one Double value per row.
  override def inputSchema: StructType = StructType(List(
    StructField("in", DoubleType)
  ))

  // Intermediate aggregation buffer: running sum (total) and row count (amount).
  override def bufferSchema: StructType = StructType(List(
    StructField("total", DoubleType),
    StructField("amount", IntegerType)
  ))

  // Type of the final aggregated result.
  override def dataType: DataType = DoubleType

  // The same input always yields the same output. (The original comment
  // mis-described this flag as "whether input and output types match".)
  override def deterministic: Boolean = true

  // Initial buffer values: sum = 0.0, count = 0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0
  }

  // Per-partition accumulation: called once per input row of each group.
  // BUG FIX: guard against null input — the original called
  // input.getDouble(0) unconditionally, which fails on a null salary value.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0) // running sum
      buffer(1) = buffer.getInt(1) + 1                     // running count
    }
  }

  // Global aggregation: merge two partial buffers into buffer1.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }

  // Final result: sum / count. Returns null when no non-null value was seen,
  // matching SQL AVG semantics instead of producing NaN from 0.0 / 0.
  override def evaluate(buffer: Row): Any = {
    if (buffer.getInt(1) == 0) null
    else buffer.getDouble(0) / buffer.getInt(1)
  }
}