SparkSQL 自定义函数

1. UDF

在 Hive 中,我们实现的 UDF 必须将方法命名为 evaluate ,而 Spark SQL 中却没有这么无理的要求,我们可以根据所需随意自定义函数。

语法格式:

spark.udf.register(函数名,函数体)

🌰 将日期变化格式:

原数据 birthday.txt 预览:

Michael, 2020/Nov/12 15:34:56
Andy, 2020/Dec/05 17:27:38
Justin, 2020/Dec/27 22:48:23

程序实现:

def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
    .appName("UDF")
    .master("local[2]")
    .getOrCreate()

    val sc = SparkContext.getOrCreate()

    import spark.implicits._

    // 加载数据源,将其转化为 DataFrame
    var df = sc.textFile("birthday.txt")
    .map(_.split(","))
    .map(line => (line(0), line(1)))
    .toDF("name", "birthday") // 转型时指定字段的名称


    // 自定义函数的实现
    spark.udf.register("TranBirth", (dt: String) => {
        // 日期的输入格式(US)
        val parser = new SimpleDateFormat("yyyy/MMM/dd HH:mm:ss", Locale.US)
        // 日期的输出格式
        val formatter = new SimpleDateFormat("dd-MM-yyyy HH:mm:ss")
        // 将输入日期转型
        formatter.format(parser.parse(dt))
    })

    // 建立临时视图
    df.createOrReplaceTempView("birthday")

    // SQL 语句中使用自定义函数
    spark.sql("select name, TranBirth(birthday.birthday) from birthday").show()

}

输出:

+-------+-------------------+
|   name|TranBirth(birthday)|
+-------+-------------------+
|Michael|12-11-2020 15:34:56|
|   Andy|05-12-2020 17:27:38|
| Justin|27-12-2020 22:48:23|
+-------+-------------------+

2. UDAF

强类型的 DataSet 和弱类型的 DataFrame 都提供了相关的聚合函数,如 count()countDistinct()avg()min() 等。除此之外,用于可以设定自己的聚合函数,通过继承 UserDefinedAggregateFunction 实现用户自定义弱类型函数,自 Spark 3.0 之后,UserDefinedAggregateFunction 已不推荐使用了,可以统一采用强类型聚合函数 Aggergator

2.1 RDD 实现

🌰实例:计算平均工资

val rdd = sc.makeRDD(List(("Michael", 3000),("Andy", 3300), ("Justin", 4500)))
  .map{
    case(name, age) => (age, 1)
  }
  .reduce((t1, t2) => (t1._1 + t2._1 , t1._2 +  t2._2))
println(rdd._1 / rdd._2 * 1.0)  // 输出: 3600.0

2.2 UDAF 弱类型实现

🌰实例:计算平均工资

数据预览 employees.json

{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}

自定义类,继承 UserDefinedAggregateFunction 并实现其中的方法。

class AverageUDAF extends UserDefinedAggregateFunction {

    // 聚合函数输入参数的数据类型
    override def inputSchema: StructType = 
    StructType(Array(StructField("salary", IntegerType)))

    // 聚合函数缓冲区中值的数据类型(age,count)
    override def bufferSchema: StructType =
    StructType(Array(StructField("sum", LongType), StructField("count", LongType)))

    // 函数返回值的数据类型
    override def dataType: DataType = DoubleType

    // 稳定性:对于相同的输入是否一直返回相同的输出。
    override def deterministic: Boolean = true

    // 函数缓冲区初始化
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L
        buffer(1) = 0L
    }

    // 更新缓冲区中的数据
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (!input.isNullAt(0)) {
            buffer(0) = buffer.getLong(0) + input.getInt(0)
            buffer(1) = buffer.getLong(1) + 1
        }
    }

    // 合并缓冲区
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        buffer1(0) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    // 计算最终结果
    override def evaluate(buffer: Row): Double = 
    buffer.getLong(0) / buffer.getLong(1) * 1.0
}
spark.udf.register("AverageUDAF", new AverageUDAF)
val df = spark.read.json("employees.json")

df.createOrReplaceTempView("employees")

spark.sql("select name, AverageUDAF(salary) from employees").show()

2.3 UDAF 强类型

🌰实例:计算平均工资

数据预览 employees.json

{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}

自定义类,继承 Aggregator 并实现其中的方法。

import org.apache.spark.sql.expressions.Aggregator

// 输入数据类型
case class Emp(name: String, salary: Long)

// 缓冲数据类型
case class AvgBuffer(var sum: Long, var count: Long)

class AgeUDAF extends Aggregator[Emp, AvgBuffer, Double] {

    override def zero: AvgBuffer = AvgBuffer(0L, 0L)

    override def reduce(b: AvgBuffer, a: Emp): AvgBuffer = {
        b.sum = b.sum + a.salary
        b.count += 1
        b
    }

    override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
        b1.sum += b2.sum
        b1.count += b2.count
        b1
    }

    override def finish(reduction: AvgBuffer): Double =
    reduction.sum.toDouble / reduction.count

    override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
spark.udf.register("AgeUDAF", functions.udaf(new AgeUDAF))
val df = spark.read.json("employees.json")

df.createOrReplaceTempView("employees")

spark.sql("select AgeUDAF(salary) from employees").show()

 


❤️ END ❤️
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

JOEL-T99

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值