大数据技术之Spark SQL——UDF/UDAF函数

自定义函数UDF

val conf = new SparkConf().setMaster("local").setAppName("UDF")
val spark = SparkSession.builder().config(conf).getOrCreate()

import spark.implicits._

val df = spark.read.json("datas/user.json")
df.createOrReplaceTempView("user")

spark.udf.register("prefixName", (name:String) => {
    "Name: " + name
})

spark.sql("select age, prefixName(username) from user").show

spark.close()

自定义聚合函数UDAF

——弱类型函数实现

val conf = new SparkConf().setMaster("local").setAppName("UDF")
val spark = SparkSession.builder().config(conf).getOrCreate()

import spark.implicits._

/*
    自定义聚合函数类:计算年龄平均值
    1. 继承UserDefineAggregateFunction
    2. 重写方法
*/

/* UserDefineAggregateFunction默认已经不推荐使用 */

class MyAvgUDAF extends UserDefineAggregateFunction {
    // 输入数据的结构
    override def inputSchema: StructType = {
        StructType(
            Array(
                StructField("age", LongType)
            )
        )
    }

    // 缓冲区数据的结构:Buffer
    override def bufferSchema: StructType = {
        StructType(
            Array(
                StructField("total", LongType),
                StructField("count", LongType)
            )
        )
    }

    // 函数计算结果的数据类型:Out
    override def dataType: DataType = LongType

    // 函数的稳定性
    override def deterministic: Booleaj = true

    // 缓冲区初始化
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer.update(0,0L)
        buffer.update(1,0L)
    }

    // 根据输入的值更新缓冲区数据
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer.update(0, buffer.getLong(0) + input.getLong(0))
        buffer.update(1, buffer.getLong(1) + 1)
    }

    // 缓冲区数据合并
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
        buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }

    // 计算平均值
    override def evaluate(buffer: Row): Any = {
        buffer.getLong(0) / buffer.getLong(1)
    }
}



spark.close()

 弱类型的函数没有类型的概念,只能根据传参的顺序来操作,这样在使用上容易出错。

强类型可以通过属性的方式实现,这样在实现时就不容易出错。

——强类型函数实现

/*
    自定义聚合函数类:计算年龄的平均值
    1. 继承org.apache.spark.sql.expressions.Aggregator, 定义泛型
        IN:输入的数据类型 User
        BUF:缓冲区的数据类型 Buff
        OUT:输出的数据类型 User
    2. 重写方法(6)
*/
case class User(username: String, age: Long)

case class Buff(var total: Long, var count: Long)

class MyAvgUDAF extends Aggregator[]{

    // z & zero : 初始值或零值
    // 缓冲区的初始化
    override def zero: Buff = {
        Buff(0L, 0L)
    }

    // 根据输入的数据更新缓冲区的数据
    override def reduce(b: Any, a: User): Buff = {
        buff.total = buff.total + in.age
        buff.count = buff.count + 1
        buff
    }

    // 合并缓冲区
    override def merge(b1: Any, b2: Any): Buff = {
        buff1.total = buff1.total + buff2.total
        buff1.count = buff.count + buff2.count
        buff1
    }

    // 计算结果
    override def finish(reduction: Any): Long = {
        buff.total / buff.count
    }

    // 缓冲区的编码操作
    override def bufferEncoder: Encoder[Buff] = Encoders.product

    // 输出的编码操作
    override def outputEncoder: Encoder[Long] =  Encoders.scalaLong
}


val df = spark.read.json("datas/user.json")

// 早期版本中,spark不能在sql中使用强类型UDAF操作
// 早期的UDAF强类型聚合函数使用DSL语法操作
val ds: Dataset[User] = df.as[User]

// 将UDAF函数转换为查询的列对象
val udafCol: TypedColumn[User, Long] = new MyAvgUDAF().toColumn

ds.select(udafCol).show()

强类型的Dataset和弱类型的DataFrame都提供了相关的聚合函数,如count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值