用户自定义聚合函数
package doc.df
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
/**
* @Program: doc.df
* @Author: huangwei
* @Date: 2019/9/16 17:34
* @description: 用户自定义聚合函数
* 无类型的用户自定义函数(Untyped User-Defined Aggregate Functions)
* 实现无类型用户自定义聚合函数需要继承抽象类UserDefinedAggregateFunction,并重写该类的8个函数
*/
object UserDefindUntypedAggregate {
object MyAverage extends UserDefinedAggregateFunction{
// 1、inputSchema 定义输入数据的Schema,要求类型是StructType,它的参数是由StructField类型构成的列表
// 这里定义salary列的Schema,首先使用StructField声明salary列的名字salaryColumn,数据类型为Long,这里只输入salary这一列,所以StructField构成的列表只有一个元素
// ::是Scala的操作符,与空集合Nil操作后生成一个列表
override def inputSchema: StructType = StructType(StructField("salaryColumn",LongType)::Nil)
// 2、bufferSchema 事实上需要计算salary平均值的时候,需要用到salary的总和sum和总个数count这样的中间数据,那么就使用bufferSchema来定义
override def bufferSchema: StructType = StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)
// 3、dataType 我们需要自定义聚合函数最终数据类型进行说明,使用dataType函数,这里salary的类型为Double类型
override def dataType: DataType = DoubleType
// 4、deterministic 用户对输入数据进行一致性检验,是一个布尔值,当为True时,表示对于同样的输入会得到同样的输出,因为对于同样的Salary输入,肯定要得到相同的Salary平均著,所以定义为true
override def deterministic: Boolean = true
// 5、initialize 用于初始化缓存数据,salary的缓存数据有两个:sum和count,需要初始化sum为0L,count为0L
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
buffer(1) = 0L
}
// 6、update 当有新的输入数据时,更新缓存变量,这里有新的salary输入时,需要更新sum值,并将count加1
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// 输入非空
if (!input.isNullAt(0)){
buffer(0) = buffer.getLong(0) + input.getLong(0) // sum = sum+输入的salary
buffer(1) = buffer.getLong(1) + 1 // count = count + 1
}
}
// 7、merge 将更新的缓存变量存入到缓存中
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// 8、evalute 用于计算最后的结果,这里用于计算平均值
override def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("Spark SQL user-defined DataFrames aggregation example")
.master("local")
.getOrCreate()
// 注册名为myAverage的自定义集成算子MyAverage
spark.udf.register("myAverage", MyAverage)
val df = spark.read.json("E:\\IdeaProjects\\SparkProject\\src\\main\\resources\\employess.json")
df.createOrReplaceTempView("employee")
df.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
val avg_salary = spark.sql("SELECT myAverage(salary) as average_salary FROM employee")
avg_salary.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
}
}