自定义函数UDF
val conf = new SparkConf().setMaster("local").setAppName("UDF")
val spark = SparkSession.builder().config(conf).getOrCreate()
import spark.implicits._
val df = spark.read.json("datas/user.json")
df.createOrReplaceTempView("user")
spark.udf.register("prefixName", (name:String) => {
"Name: " + name
})
spark.sql("select age, prefixName(username) from user").show
spark.close()
自定义聚合函数UDAF
——弱类型函数实现
val conf = new SparkConf().setMaster("local").setAppName("UDF")
val spark = SparkSession.builder().config(conf).getOrCreate()
import spark.implicits._
/*
自定义聚合函数类:计算年龄平均值
1. 继承UserDefineAggregateFunction
2. 重写方法
*/
/* UserDefineAggregateFunction默认已经不推荐使用 */
class MyAvgUDAF extends UserDefineAggregateFunction {
// 输入数据的结构
override def inputSchema: StructType = {
StructType(
Array(
StructField("age", LongType)
)
)
}
// 缓冲区数据的结构:Buffer
override def bufferSchema: StructType = {
StructType(
Array(
StructField("total", LongType),
StructField("count", LongType)
)
)
}
// 函数计算结果的数据类型:Out
override def dataType: DataType = LongType
// 函数的稳定性
override def deterministic: Booleaj = true
// 缓冲区初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0,0L)
buffer.update(1,0L)
}
// 根据输入的值更新缓冲区数据
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0, buffer.getLong(0) + input.getLong(0))
buffer.update(1, buffer.getLong(1) + 1)
}
// 缓冲区数据合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
}
// 计算平均值
override def evaluate(buffer: Row): Any = {
buffer.getLong(0) / buffer.getLong(1)
}
}
spark.close()
弱类型的函数没有类型的概念,只能根据传参的顺序来操作,这样在使用上容易出错。
强类型可以通过属性的方式实现,这样在实现时就不容易出错。
——强类型函数实现
/*
自定义聚合函数类:计算年龄的平均值
1. 继承org.apache.spark.sql.expressions.Aggregator, 定义泛型
IN:输入的数据类型 User
BUF:缓冲区的数据类型 Buff
OUT:输出的数据类型 User
2. 重写方法(6)
*/
case class User(username: String, age: Long)
case class Buff(var total: Long, var count: Long)
class MyAvgUDAF extends Aggregator[]{
// z & zero : 初始值或零值
// 缓冲区的初始化
override def zero: Buff = {
Buff(0L, 0L)
}
// 根据输入的数据更新缓冲区的数据
override def reduce(b: Any, a: User): Buff = {
buff.total = buff.total + in.age
buff.count = buff.count + 1
buff
}
// 合并缓冲区
override def merge(b1: Any, b2: Any): Buff = {
buff1.total = buff1.total + buff2.total
buff1.count = buff.count + buff2.count
buff1
}
// 计算结果
override def finish(reduction: Any): Long = {
buff.total / buff.count
}
// 缓冲区的编码操作
override def bufferEncoder: Encoder[Buff] = Encoders.product
// 输出的编码操作
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
val df = spark.read.json("datas/user.json")
// 早期版本中,spark不能在sql中使用强类型UDAF操作
// 早期的UDAF强类型聚合函数使用DSL语法操作
val ds: Dataset[User] = df.as[User]
// 将UDAF函数转换为查询的列对象
val udafCol: TypedColumn[User, Long] = new MyAvgUDAF().toColumn
ds.select(udafCol).show()
强类型的Dataset和弱类型的DataFrame都提供了相关的聚合函数,如count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数。