在Kotlin中使用Spark SQL的UDF和UDAF函数

1. 项目结构与依赖

1.1 项目依赖

使用gradle:

在项目的build.gradle.kts添加

dependencies {
    implementation("org.apache.spark:spark-sql_2.12:3.3.1")
}

使用maven:

在模块的pom.xml中添加

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.12</artifactId>
            <version>3.3.1</version>
        </dependency>

2. UDF的使用与实现

UDF,即用户自定义函数,允许用户在SQL查询中使用自定义的函数。下面案例做了一个简单的案例,将首字母变为大写。

2.1 数据源

准备数据源使用JSON数据作为数据格式,保存到项目的根路径下的`data/user.txt`文件。

{"name":"zhangsan","age":19,"gender":"boy"}
{"name":"lisi","age":20,"gender":"boy"}
{"name":"wangwu","age":21,"gender":"boy"}
{"name":"zhaoliu","age":22,"gender":"boy"}
{"name":"sunqi","age":23,"gender":"boy"}
{"name":"zhouba","age":24,"gender":"boy"}
{"name":"wujiu","age":25,"gender":"boy"}
{"name":"zhengshi","age":26,"gender":"boy"}

2.2 代码示例

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.types.DataTypes
import java.util.*

class SparkSQL_UDF {
    fun f1() {
        val sparkSession = SparkSession.builder()
            .master("local")
            .appName("Kotlin Spark UDF")
            .orCreate

        // 读取JSON数据并创建视图
        sparkSession.read().json("data/user.txt")
            .createOrReplaceTempView("user")
        
        // 注册UDF函数,将名字的首字母大写
        sparkSession.udf().register("nameHeaderUpper",
            UDF1 { name: String ->
                name.substring(0, 1).uppercase(Locale.getDefault()) + name.substring(1)
            },
            DataTypes.StringType)

        // 使用注册的UDF函数进行SQL查询
        sparkSession.sql("select nameHeaderUpper(name) as name from user").show()

        sparkSession.stop()
    }

    companion object {
        @JvmStatic
        fun main(args: Array<String>) {
            SparkSQL_UDF().f1()
        }
    }
}

/*输出结果
+---------------------+
|nameHeaderUpper(name)|
+---------------------+
|             Zhangsan|
|                 Lisi|
|               Wangwu|
|              Zhaoliu|
|                Sunqi|
|               Zhouba|
|                Wujiu|
|             Zhengshi|
+---------------------+
*/

2.3 代码解析

1. 创建SparkSession:

使用local模式进行测试

   val sparkSession = SparkSession.builder()
       .master("local")
       .appName("Kotlin Spark UDF")
       .orCreate
2. 读取数据并创建视图:

创建名为user的视图

   sparkSession.read().json("data/user.txt")
       .createOrReplaceTempView("user")
 3. 注册UDF函数:

使用kotlin的lambda表达式来简化UDF函数的创建 要注意这里是Java中的UDF1{}而不是Scala中的Function1{}

   sparkSession.udf().register("nameHeaderUpper",
       UDF1 { name: String ->
           name.substring(0, 1).uppercase(Locale.getDefault()) + name.substring(1)
       },
       DataTypes.StringType)
4. 执行SQL查询:

用sparkSession对象调用sql方法进行查询 并将结果展示到控制台

  sparkSession.sql("select nameHeaderUpper(name) as name from user").show()
5. 停止SparkSession:

释放资源
调用close和stop方法都可以

   sparkSession.stop()

3. UDAF的使用与实现

UDAF,即用户自定义聚合函数,允许用户定义复杂的聚合逻辑,如求平均值、总和等。

下面是一个简单案例来实现求年龄的平均值

3.1 代码示例

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions
import java.io.Serializable

class SparkSQL_UDAF {
    fun f1() {
        val sparkSession = SparkSession.builder()
            .master("local")
            .appName("Kotlin Spark UDAF")
            .orCreate

        // 定义聚合器
        val agg = object : Aggregator<Long, Buffer, Long>(){
            override fun reduce(b: Buffer?, a: Long?): Buffer {
                val updatedBuffer = b ?: Buffer(0, 0)
                updatedBuffer.cnt++
                updatedBuffer.count += a!!
                return updatedBuffer
            }

            override fun outputEncoder(): Encoder<Long> {
                return Encoders.LONG()
            }

            override fun zero(): Buffer {
                return Buffer(0L, 0L)
            }

            override fun bufferEncoder(): Encoder<Buffer> {
                return Encoders.bean(Buffer::class.java)
            }

            override fun finish(reduction: Buffer?): Long {
                return reduction?.count?.div(reduction.cnt) ?: 0L
            }

            override fun merge(b1: Buffer?, b2: Buffer?): Buffer {
                return Buffer(
                    count = (b1?.count ?: 0) + (b2?.count ?: 0),
                    cnt = (b1?.cnt ?: 0) + (b2?.cnt ?: 0)
                )
            }
        }

        // 注册UDAF函数
        sparkSession.udf().register("avgAge", functions.udaf(agg, Encoders.LONG()))

        // 使用注册的UDAF函数进行SQL查询
        sparkSession.read().json("data/user.txt").createOrReplaceTempView("user")
        sparkSession.sql("select avgAge(age) as avg_age from user").show()

        sparkSession.stop()
    }

    companion object {
        @JvmStatic
        fun main(args: Array<String>) {
            SparkSQL_UDAF().f1()
        }
    }
}

data class Buffer(var count: Long, var cnt: Long) : Serializable {
    constructor() : this(0, 0)
}

/*输出结果
+-----------+
|avgage(age)|
+-----------+
|         22|
+-----------+
*/

3.2 代码解析

1. Buffer数据类:

用来缓存聚合的中间过程的类 这里不能使用Scala中的二元组和kotlin中的Pair类 因为都不能修改其中的数据 而且需注意 需要有空参构造 (Kotlin中的data class 默认没有空参构造)
 

   data class Buffer(var count: Long, var cnt: Long) : Serializable {
       constructor() : this(0, 0)
   }
2. 定义UDAF函数:

为了看着清晰 创建了一个匿名内部类对象agg 继承自Aggregator 需要定义输入输出和缓存的泛型

实现其对应的方法

val agg = object : Aggregator<Long, Buffer, Long>(){
       override fun reduce(b: Buffer?, a: Long?): Buffer {
           val updatedBuffer = b ?: Buffer(0, 0)
           updatedBuffer.cnt++
           updatedBuffer.count += a!!
           return updatedBuffer
       }

       override fun outputEncoder(): Encoder<Long> {
           return Encoders.LONG()
       }

       override fun zero(): Buffer {
           return Buffer(0L, 0L)
       }

       override fun bufferEncoder(): Encoder<Buffer> {
           return Encoders.bean(Buffer::class.java)
       }

       override fun finish(reduction: Buffer?): Long {
           return reduction?.count?.div(reduction.cnt) ?: 0L
       }

       override fun merge(b1: Buffer?, b2: Buffer?): Buffer {
           return Buffer(
               count = (b1?.count ?: 0) + (b2?.count ?: 0),
               cnt = (b1?.cnt ?: 0) + (b2?.cnt ?: 0)
           )
       }
   }

其中 reduce是对输入的数据进行聚合(这里是累加)

outputEncoder和bufferEncoder是将输出和缓存的结果进行序列化

zero是对数据赋初值

merge是 多分区的数据进行合并时调用的方法 用于将缓存数据合并 (这里时Buffer类)

finish是输出最终结果

要十分注意在这些方法中的空安全处理(有时候kt的空安全挺烦人的) 不要轻易用非空断言

3. 注册UDAF函数:

使用functions.udaf()注册 要指明输出结果的类型

   sparkSession.udf().register("avgAge", functions.udaf(agg, Encoders.LONG()))
4. 执行SQL查询:

将读入的数据注册为视图 调用sparkSession的sql方法查询 并将结果在控制台打印输出

   sparkSession.read().json("data/user.txt").createOrReplaceTempView("user")
   sparkSession.sql("select avgAge(age) as avg_age from user").show()

5.释放资源:

调用stop过close方法释放资源

   sparkSession.stop()

4. 总结

在不适用spark kotlin api的情况下 用kotlin来写SparkSql基本是使用spark提供的Java api 因为kt与Scala不是完全兼容 所以要注意其中的一些高阶函数还有元组的使用和序列化等问题 

  • 17
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值