SparkSql自定义聚合函数(弱类型DataFrame)求平均值
弱类型DataFrame提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min();
除此之外,用户可以设定自己的自定义聚合函数。
代码测试前请确保各个组件均已安装
与强类型的代码思想差不多一样,但是书写有些变化,可以参考强类型写法:
强类型DataSet
1、环境准备
1、准备json文件:
{"name": "zhangsan","age": 20}
{"name": "lisi","age": 30}
{"name": "wangwu","age": 40}
2、使用IDEA软件,创建maven工程
3、添加pom依赖
2、maven工程pom依赖
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.spark.bigdata</groupId>
<artifactId>sparkSql_Test</artifactId>
<version>1.0-SNAPSHOT</version>
<dependencies>
<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<version>2.4.4</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
<version>2.4.4</version>
</dependency>
</dependencies>
</project>
3、创建类 AgeAvgFunction.Class
// 1)声明用户自定义聚合函数
// 2)继承UserDefinedAggregateFunction类
// 3)重写方法
class AgeAvgFunction extends UserDefinedAggregateFunction {
// 函数输入的数据结构
override def inputSchema: StructType = {
new StructType().add("age", LongType)
}
// 计算时的数据结构(缓冲区的计算结构)
override def bufferSchema: StructType = {
new StructType().add("sum", LongType).add("count", LongType)
}
// 数据计算完毕之后的结构(函数返回数据类型)
override def dataType: DataType = DoubleType
// 稳定性
override def deterministic: Boolean = true
// 当前计算之前缓冲区的初始化(就是sum和count初始化是什么值)
// 不考虑类型,只考虑结构
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L // 0代表第一个sum 初始化为0
buffer(1) = 0L // 1代表第二个count 初始化为0
}
// 根据查询结构更新缓冲区数据
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getLong(0) + input.getLong(0)
buffer(1) = buffer.getLong(1) + 1
}
// 将多个节点的缓冲区合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
// sum=缓冲区第一个位置的sum加上缓冲区第二个位置的sum
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
// count
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// 计算最终结果
override def evaluate(buffer: Row): Any = {
buffer.getLong(0).toDouble / buffer.getLong(1)
}
}
4、创建SparkSql_aggregate.Scala
object SparkSql_aggregate {
def main(args: Array[String]): Unit = {
// 创建配置对象
val conf = new SparkConf().setMaster("local[*]").setAppName("SparkSql_aggregate")
// 创建SparkSql的环境对象
val spark: SparkSession = new sql.SparkSession.Builder().config(conf).getOrCreate()
/* 进行转换之前需要引入隐式转换规则
这里Spark不是报名的含义,是SparkSession对象的名字*/
import spark.implicits._
// 创建聚合函数对象
val udaf = new AgeAvgFunction
// 注册聚合函数
spark.udf.register("avgAge", udaf)
// 使用聚合函数
// 读取数据,构建DF
val df: DataFrame = spark.read.json("input/sparksql.json")
// 数据视图
df.createOrReplaceTempView("user")
// Sql语句
spark.sql("select avgAge(age) from user").show()
// 释放资源
spark.stop()
}
}
5、运行SparkSql_Aggregate.Scala
得到结果: