1.UDF
现有数据的字段包括username和age,要求查询时在username的结果前加上字符串name:
,如name:张三
。
代码如下:
def main(args: Array[String]): Unit = {
//创建上下文环境配置对象
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSql")
//创建 SparkSession 对象
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
val df: DataFrame = spark.read.json("datas/user.json")
//创建临时表
df.createOrReplaceTempView("user")
//注册udf
spark.udf.register("prefix",(username:String) => "name:"+username)
//应用udf
spark.sql("select prefix(username),age from user").show
spark.stop()
}
2.UDAF
强类型的 Dataset 和弱类型的 DataFrame 都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数。通过继承 UserDefinedAggregateFunction 来实现用户自定义弱类型聚合函数。从 Spark3.0 版本后,UserDefinedAggregateFunction 已经不推荐使用了。可以统一采用强类型聚合函数Aggregator。
需求:计算平均值的聚合函数
2.1 UDAF-弱类型
- 1.自定义聚合类
class MyAveragUDAF extends UserDefinedAggregateFunction {
//输入数据的结构
override def inputSchema: StructType = {
StructType(
Array(
StructField("age",LongType)
)
)
}
//聚合缓冲区数据的结构
override def bufferSchema: StructType = {
StructType(
Array(
StructField("sum",LongType),
StructField("count",LongType)
)
)
}
//输出数据的结构
override def dataType: DataType = LongType
//稳定性
override def deterministic: Boolean = true
//初始化缓冲区的数据
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0,0L)
buffer.update(1,0L)
}
//遍历传入数据的时候更新缓冲区的方法
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0,buffer.getLong(0) + input.getLong(0))
buffer.update(1,buffer.getLong(1) + 1)
}
//缓冲区数据合并的数据,更新buffer1的数据
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0,buffer1.getLong(0) + buffer2.getLong(0))
buffer1.update(1,buffer1.getLong(1) + buffer2.getLong(1))
}
//计算返回最后的结果
override def evaluate(buffer: Row): Any = {
buffer.getLong(0)/buffer.getLong(1)
}
}
- 2.使用该聚合函数
def main(args: Array[String]): Unit = {
//创建上下文环境配置对象
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSql")
//创建 SparkSession 对象
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
//注册自定义类
spark.udf.register("ageAvg",new MyAveragUDAF())
//创建DataFrame
val df: DataFrame = spark.read.json("datas/user.json")
//创建临时表
df.createOrReplaceTempView("user")
//使用自定义函数
spark.sql("select ageAvg(age) from user").show
spark.stop()
}
- 3.结果展示
原始数据:
+---+--------+
|age|username|
+---+--------+
| 18| 张三|
| 19| 李四|
| 20| 王五|
+---+--------+
查询结果:
+-----------------+
|myaveragudaf(age)|
+-----------------+
| 19|
+-----------------+
2.1 UDAF-强类型
弱类型需要通过数据的顺序通过下标索引的方式操作数据,容易出错,强类型可以通过类属性的方式访问数据。
- 1.自定义聚合类
/*
*1.继承package org.apache.spark.sql.expressions.Aggregator,定义范性
IN:输入数据类型
BUF:缓冲区数据类型
OUT:输出数据类型
*2.重写6个方法
*/
case class Buff(var sum:Long,var count:Long)
class MyAveragUDAF extends Aggregator[Long,Buff,Long] {
//缓冲区初始化
override def zero: Buff = {
Buff(0L,0L)
}
//根据输入的数据更新缓冲区
override def reduce(buff: Buff, in: Long): Buff = {
buff.sum = buff.sum + in
buff.count = buff.count + 1
buff
}
//合并数据
override def merge(buff1: Buff, buff2: Buff): Buff = {
buff1.sum = buff1.sum + buff2.sum
buff1.count = buff1.count + buff2.count
buff1
}
//计算结果
override def finish(reduction: Buff): Long = {
reduction.sum / reduction.count
}
//编码,自定义类和Scala自带类固定写法
override def bufferEncoder: Encoder[Buff] = Encoders.product
//解码,自定义类和Scala自带类固定写法
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
- 2.使用聚合类
def main(args: Array[String]): Unit = {
//创建上下文环境配置对象
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSql")
//创建 SparkSession 对象
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
//注册自定义类
spark.udf.register("ageAvg", functions.udaf(new MyAveragUDAF()))
//创建DataFrame
val df: DataFrame = spark.read.json("datas/user.json")
//创建临时表
df.createOrReplaceTempView("user")
//使用自定义函数
spark.sql("select ageAvg(age) from user").show
spark.stop()
}