Spark SQL Custom Aggregate Function (Strongly Typed Dataset): Computing an Average
Strongly typed Datasets come with built-in aggregate functions such as count(), countDistinct(), avg(), max(), and min(); beyond these, users can define their own custom aggregate functions. The built-in ones can be applied directly, as the sketch below shows.
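A minimal sketch of the built-in aggregates, assuming the SparkSession spark and the Dataset[UserBean] userDS built in section 4 are already in scope:

import org.apache.spark.sql.functions._
import spark.implicits._
// Global aggregation: a select containing only aggregate expressions needs no groupBy
userDS.select(avg($"age"), max($"age"), min($"age")).show()
userDS.select(countDistinct($"name")).show()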
Before testing the code, make sure all the required components are installed.
1. Environment Setup
1. Prepare a JSON file (see the quick parse check after this list):
{"name": "zhangsan","age": 20}
{"name": "lisi","age": 30}
{"name": "wangwu","age": 40}
2. Create a Maven project in IDEA
3. Add the pom dependencies (see section 2)
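Each line of the JSON file is a standalone JSON object, which is the line-delimited format spark.read.json expects. A quick way to confirm the file parses, assuming a SparkSession named spark and the path input/sparksql.json used later in section 4:

spark.read.json("input/sparksql.json").printSchema()
// root
//  |-- age: long (nullable = true)
//  |-- name: string (nullable = true)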
2. Maven Project pom Dependencies
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.spark.bigdata</groupId>
    <artifactId>sparkSql_Test</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.12</artifactId>
            <version>2.4.4</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.12</artifactId>
            <version>2.4.4</version>
        </dependency>
    </dependencies>
</project>
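Running the code inside IDEA with the Scala plugin installed works with this pom as-is. If you also want to build from the command line with mvn package, you will likely need a Scala compiler plugin declared inside the <project> element; a minimal sketch using scala-maven-plugin (the version shown is an assumption):

<build>
    <plugins>
        <!-- Compiles the Scala sources during the Maven build; version is an assumption -->
        <plugin>
            <groupId>net.alchim31.maven</groupId>
            <artifactId>scala-maven-plugin</artifactId>
            <version>3.4.6</version>
            <executions>
                <execution>
                    <goals>
                        <goal>compile</goal>
                        <goal>testCompile</goal>
                    </goals>
                </execution>
            </executions>
        </plugin>
    </plugins>
</build>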
3. Create the Class MyAgeAvgFunctionClass.scala
Since the function is strongly typed, we need case classes for the input row and the aggregation buffer. Note that age is declared as BigInt rather than Int: spark.read.json infers whole numbers as LongType, and Spark will not narrow a long down to an int when converting the DataFrame to a Dataset[UserBean].

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

case class UserBean(name: String, age: BigInt)
case class AvgBuffer(var sum: BigInt, var count: Int)

// 1) Declare the user-defined aggregate function (strongly typed)
// 2) Extend the Aggregator class with the type parameters [input type, buffer type, output type]
// 3) Override its methods
class MyAgeAvgFunctionClass extends Aggregator[UserBean, AvgBuffer, Double] {
  // zero: the initial (empty) buffer
  override def zero: AvgBuffer = {
    AvgBuffer(0, 0)
  }

  /**
   * Aggregate one input row into the buffer.
   * @param b the current buffer
   * @param a the input row
   * @return the updated buffer
   */
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  /**
   * Merge two buffers (combines partial results from different partitions).
   * @param b1 one buffer
   * @param b2 another buffer
   * @return the merged buffer
   */
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // finish: compute the final result from the completed buffer
  override def finish(reduction: AvgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  // Encoder for the intermediate buffer type; AvgBuffer is a case class, so Encoders.product applies
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  // Encoder for the final output type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
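Before wiring the class into a Spark job, the zero/reduce/merge/finish lifecycle can be sanity-checked by hand; a minimal sketch with hypothetical values, runnable in a Scala REPL alongside the definitions above:

val agg = new MyAgeAvgFunctionClass
val p1 = agg.reduce(agg.zero, UserBean("zhangsan", 20)) // partial result of one partition
val p2 = agg.reduce(agg.zero, UserBean("lisi", 30))     // partial result of another partition
println(agg.finish(agg.merge(p1, p2)))                  // (20 + 30) / 2 = 25.0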
4. Create SparkSql_aggregate_Class.scala

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, TypedColumn}

object SparkSql_aggregate_Class {
  def main(args: Array[String]): Unit = {
    // Create the configuration object
    val conf = new SparkConf().setMaster("local[*]").setAppName("SparkSql_aggregate")
    // Create the Spark SQL session object
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    /*
      Import the implicit conversion rules before converting between DataFrames and Datasets.
      Note that `spark` here is not a package name but the name of the SparkSession object above.
     */
    import spark.implicits._
    // Create an instance of the aggregate function
    val udaf = new MyAgeAvgFunctionClass
    // Convert the aggregate function into a typed query column
    val avgCol: TypedColumn[UserBean, Double] = udaf.toColumn.name("avgAge")
    // Read the data
    val df: DataFrame = spark.read.json("input/sparksql.json")
    // Convert the DataFrame to a strongly typed Dataset
    val userDS: Dataset[UserBean] = df.as[UserBean]
    // Apply the aggregate function as a typed column
    userDS.select(avgCol).show()
    // Release resources
    spark.stop()
  }
}
5. Run SparkSql_aggregate_Class.scala
The result:
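Given the three records, the average age is (20 + 30 + 40) / 3 = 30.0, so show() should print something like:

+------+
|avgAge|
+------+
|  30.0|
+------+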