1. Weakly-typed UDAF
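A weakly-typed UDAF extends UserDefinedAggregateFunction. The input row and the aggregation buffer are accessed by position through Row, so the column types are not checked at compile time; a mistake such as reading a Long column as a String only fails at runtime.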
1.1 Custom class UserAvg
package com.atguigu.bigdata.sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

/**
 * Computes the average age.
 *
 * @author tianmin
 * @date 2020-02-28 - 16:28
 * @note weakly-typed UDAF
 */
class UserAvg extends UserDefinedAggregateFunction {

  // Schema of the input arguments
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Schema of the intermediate aggregation buffer
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // Data type of the final result
  override def dataType: DataType = DoubleType

  // Whether the function always returns the same result for the same input
  override def deterministic: Boolean = true

  // Initialize the buffer before aggregation starts
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // sum
    buffer(1) = 0L // count
  }

  // Update the buffer with one input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) { // skip null ages to avoid a NullPointerException
      buffer(0) = buffer.getLong(0) + input.getLong(0) // sum
      buffer(1) = buffer.getLong(1) + 1                // count
    }
  }

  // Merge buffers coming from different partitions/nodes
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) // sum
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) // count
  }

  // Compute the final result from the buffer
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
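Note: starting with Spark 3.0, UserDefinedAggregateFunction is deprecated in favor of the strongly-typed Aggregator registered through functions.udaf. A minimal sketch of an equivalent average on Spark 3.x (the names AvgBuffer and MyAvg are illustrative, not from the original):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

case class AvgBuffer(var sum: Long, var count: Long)

// Illustrative Spark 3.x replacement for UserAvg
object MyAvg extends Aggregator[Long, AvgBuffer, Double] {
  def zero: AvgBuffer = AvgBuffer(0L, 0L)
  def reduce(b: AvgBuffer, a: Long): AvgBuffer = { b.sum += a; b.count += 1; b }
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  def finish(b: AvgBuffer): Double = b.sum.toDouble / b.count
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Registration: spark.udf.register("myAvg", functions.udaf(MyAvg))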
2. Application
package com.atguigu.bigdata.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Tests the weakly-typed UDAF.
 *
 * @author tianmin
 * @date 2020-02-28 - 17:13
 */
object TestAvg {

  def main(args: Array[String]): Unit = {
    // Spark configuration
    val config: SparkConf = new SparkConf().setMaster("local[*]").setAppName("test01")
    // SparkSession object
    val spark: SparkSession = SparkSession.builder().config(config).getOrCreate()
    // Create the aggregate function
    val userAvg = new UserAvg
    // Register it under the name "myAvg"
    spark.udf.register("myAvg", userAvg)
    // Import implicit conversions (not strictly required here, since only SQL is used)
    import spark.implicits._
    // Read the input data
    val frame: DataFrame = spark.read.json("input/user.json")
    // Register a temporary view
    frame.createOrReplaceTempView("user")
    // Compute the average age
    spark.sql("select myAvg(age) from user").show()
    // Release resources
    spark.stop()
  }
}
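The file input/user.json is not included in the source. spark.read.json expects one JSON object per line, so a hypothetical input (made-up names and ages) could look like:

{"name": "zhangsan", "age": 20}
{"name": "lisi", "age": 30}
{"name": "wangwu", "age": 40}

With this sample data, myAvg(age) returns (20 + 30 + 40) / 3 = 30.0.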