一、UDF
一进一出函数
/**
语法:SparkSession.udf.register(func_name: String, op: T => K)
*/
object TestSparkSqlUdf {
def main(args: Array[String]): Unit = {
// 创建 sparksql 环境对象
val conf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
val spark = SparkSession.builder().config(conf).getOrCreate()
// 引入环境对象中的隐式转换
import spark.implicits._
val df: DataFrame = spark.read.json("data/user.json")
/*
需求:给 username 字段的每个值添加前缀
*/
spark.udf.register("prefixName", name => "Name: " + name)
df.createOrReplaceTempView("user")
spark.sql("select prefixName(username), age from user").show()
// 关闭环境
spark.close()
}
}
二、UDAF
多进一出函数,即聚合函数
1. 弱类型函数
/**
自定义步骤:
1.继承 UserDefinedAggregateFunction 抽象类(已过时)
2.重写 8 个方法
*/
object TestSparkSqlUdaf {
def main(args: Array[String]): Unit = {
// 创建 sparksql 环境对象
val conf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
val spark = SparkSession.builder().config(conf).getOrCreate()
// 引入环境对象中的隐式转换
import spark.implicits._
val df: DataFrame = spark.read.json("data/user.json")
/*
需求:自定义求年龄平均值的udaf函数
*/
val myAvgUdaf = new MyAvgUdaf()
spark.udf.register("ageAvg", myAvgUdaf)
df.createOrReplaceTempView("user")
spark.sql("select ageAvg(age) from user").show()
// 关闭环境
spark.close()
}
}
// 自定义聚合函数类,实现求年龄平均值
class MyAvgUdaf extends UserDefinedAggregateFunction {
// 输入数据的结构类型
def inputSchema: StructType = {
// StructType 是样例类
StructType(Array(
// StructField 是样例类,必传参数 name: String, dataType: DataType
StructField("age", LongType)
))
}
// 缓冲区的结构类型
def bufferSchema: StructType = {
StructType(Array(
StructField("totalAge", LongType),
StructField("count", LongType)
))
}
// 输出数据的结构类型
def dataType: DataType = DoubleType
// 函数稳定性
def deterministic: Boolean = true
// 缓冲区初始化
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0L)
buffer.update(1, 0L)
}
// 接收输入数据更新缓冲区数据
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val totalAge = buffer.getLong(0)
val count = buffer.getLong(1)
val age = input.getLong(0)
buffer.update(0, totalAge + age)
buffer.update(1, count + 1)
}
// 合并缓冲区
def merge(buffer1: MutableAggregationBuffer,buffer2: Row): Unit = {
buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
}
// 计算最终结果
def evaluate(buffer: Row): Any = {
buffer.getLong(0).toDouble/buffer.getLong(1)
}
}
2. 强类型函数
2.1 Spark3.0 之前
/**
自定义步骤:
1.继承 Aggregator 抽象类,定义泛型
IN:输入数据类型
BUF:缓冲区类型
OUT:输出数据类型
2.重写 6 个方法
*/
object TestSparkSqlUdaf1 {
def main(args: Array[String]): Unit = {
// 创建 sparksql 环境对象
val conf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
val spark = SparkSession.builder().config(conf).getOrCreate()
// 引入环境对象中的隐式转换
import spark.implicits._
val df: DataFrame = spark.read.json("data/user.json")
/*
需求:自定义求年龄平均值的udaf函数
*/
// Spark3.0 之前的强类型UDAF函数必须在 DSL 语法中使用
val ds = df.as[User]
// 将UDAF函数对象转换成 DSL 语法中的查询列
val col: TypedColumn[User, Double] = new MyAvgUdaf().toColumn
ds.select(col).show()
// 关闭环境
spark.close()
}
}
// 定义封装输入的一行数据的类
case class User(username: String, age: Long)
// 定义缓冲区类
case class Buff(var totalAge: Long, var count: Long)
// 自定义聚合函数类,实现求年龄平均值
class MyAvgUdaf extends Aggregator[User, Buff, Long] {
// 缓冲区初始化
override def zero: Buff = Buff(0L, 0L)
// 根据输入数据更新缓冲区数据
override def reduce(buff: Buff, in: User): Buff = {
buff.totalAge = buff.totalAge + in.age
buff.count = buff.count + 1
buff
}
// 合并缓冲区
override def merge(buff1: Buff, buff2: Buff): Buff = {
buff1.totalAge = buff1.totalAge + buff2.totalAge
buff1.count = buff1.count + buff2.count
buff1
}
// 计算最终结果
override def finish(buff: Buff): Double = {
buff.totalAge.toDouble/buff.count
}
//DataSet 默认的编解码器,用于序列化,固定写法
//自定义类型是 product
// 缓冲区编码操作
override def bufferEncoder: Encoder[Buff] = Encoders.product
// 输出数据编码操作
// 自带类型根据类型选择
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
2.2 Spark3.0 之后
/**
自定义步骤:
1.继承 Aggregator 抽象类,定义泛型
IN:输入数据类型
BUF:缓冲区类型
OUT:输出数据类型
2.重写 6 个方法
*/
object TestSparkSqlUdaf1 {
def main(args: Array[String]): Unit = {
// 创建 sparksql 环境对象
val conf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
val spark = SparkSession.builder().config(conf).getOrCreate()
// 引入环境对象中的隐式转换
import spark.implicits._
val df: DataFrame = spark.read.json("data/user.json")
/*
需求:自定义求年龄平均值的udaf函数
*/
// Spark3.0 之后的强类型UDAF可以在 SQL 语法中使用
val myAvgUdaf = new MyAvgUdaf()
// 注册函数时需要使用 functions.udaf(func) 包装转换
spark.udf.register("ageAvg", functions.udaf(myAvgUdaf))
df.createOrReplaceTempView("user")
spark.sql("select ageAvg(age) from user").show()
// 关闭环境
spark.close()
}
}
// 定义缓冲区类
case class Buff(var totalAge: Long, var count: Long)
// 自定义聚合函数类,实现求年龄平均值
class MyAvgUdaf extends Aggregator[Long, Buff, Long] {
// 缓冲区初始化
override def zero: Buff = Buff(0L, 0L)
// 根据输入数据更新缓冲区数据
override def reduce(buff: Buff, in: Long): Buff = {
buff.totalAge = buff.totalAge + in
buff.count = buff.count + 1
buff
}
// 合并缓冲区
override def merge(buff1: Buff, buff2: Buff): Buff = {
buff1.totalAge = buff1.totalAge + buff2.totalAge
buff1.count = buff1.count + buff2.count
buff1
}
// 计算最终结果
override def finish(buff: Buff): Double = {
buff.totalAge.toDouble/buff.count
}
//DataSet 默认的编解码器,用于序列化,固定写法
//自定义类型是 product
// 缓冲区编码操作
override def bufferEncoder: Encoder[Buff] = Encoders.product
// 输出数据编码操作
// 自带类型根据类型选择
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}