测试数据:
{"name":"zhangsan","age":20}
{"name":"lisi","age":21}
{"name":"wangwu","age":22}
{"name":"zhaoliu","age":23}
{"name":"tianqi","age":24}
测试结果:
+-----+------+
|count|ageavg|
+-----+------+
|    5|  22.0|
+-----+------+
package com.spark.week3
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
/**
 * User-defined aggregate function that computes the average of a Long
 * column, returned as a Double (used here to average student ages).
 *
 * Buffer layout: slot 0 = running sum (Long), slot 1 = row count (Long).
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0
 * in favor of org.apache.spark.sql.expressions.Aggregator — consider
 * migrating if the cluster runs Spark 3.x.
 */
object Moni2 extends UserDefinedAggregateFunction {
  // Input schema: a single Long column.
  override def inputSchema: StructType =
    StructType(StructField("input", LongType) :: Nil)

  // Aggregation buffer schema: running sum and row count.
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  // Result type of the aggregation: the average as a Double.
  override def dataType: DataType = DoubleType

  // Identical input always produces identical output.
  override def deterministic: Boolean = true

  // Zero both buffer slots before aggregation begins.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Fold one input row into the buffer. Null inputs are skipped with a
  // guard clause instead of an early `return` (discouraged in Scala).
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  // Merge two partial buffers: sums and counts add component-wise.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Average = sum / count. Return null (SQL AVG semantics) rather than
  // NaN when no non-null rows were aggregated (original divided blindly).
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}
/**
 * Entry point: reads student JSON records and prints the row count and
 * average age computed by the registered UDAF.
 *
 * Uses an explicit `main` instead of `extends App` (avoids DelayedInit
 * initialization-order pitfalls) and stops the SparkSession in `finally`
 * so it is released even if the query throws.
 */
object StudentAvg {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("moni2").getOrCreate()
    try {
      // Input path may be supplied as the first CLI argument; the default
      // preserves the original hard-coded location.
      val path = if (args.nonEmpty) args(0) else "G:/data/week3monitwo.json"
      val file = spark.read.json(path)

      // Expose the DataFrame to SQL as a temporary view.
      file.createTempView("stu")

      // Register the custom average UDAF under the SQL name "student".
      spark.udf.register("student", Moni2)

      // Count all rows and compute everyone's average age.
      spark.sql("select count(1) count,student(age) ageavg from stu").show()
    } finally {
      spark.close()
    }
  }
}