详解 Spark SQL 代码开发之用户自定义函数

一、UDF

一进一出函数

/**
	语法:SparkSession.udf.register(func_name: String, op: T => K)
*/
object TestSparkSqlUdf {
    def main(args: Array[String]): Unit = {
        // 创建 sparksql 环境对象
        val conf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
        val spark = SparkSession.builder().config(conf).getOrCreate()
        
        // 引入环境对象中的隐式转换
        import spark.implicits._
        
        val df: DataFrame = spark.read.json("data/user.json")
        /*
        	需求:给 username 字段的每个值添加前缀
        */
        spark.udf.register("prefixName", name => "Name: " + name)
        
        df.createOrReplaceTempView("user")
        
        spark.sql("select prefixName(username), age from user").show()
        
        // 关闭环境
        spark.close()
    }
    
}

二、UDAF

多进一出函数,即聚合函数

1. 弱类型函数

/**
	自定义步骤:
		1.继承 UserDefinedAggregateFunction 抽象类(已过时)
		2.重写 8 个方法
*/
object TestSparkSqlUdaf {
    def main(args: Array[String]): Unit = {
        // 创建 sparksql 环境对象
        val conf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
        val spark = SparkSession.builder().config(conf).getOrCreate()
        
        // 引入环境对象中的隐式转换
        import spark.implicits._
        
        val df: DataFrame = spark.read.json("data/user.json")
        /*
        	需求:自定义求年龄平均值的udaf函数
        */
        val myAvgUdaf = new MyAvgUdaf()
        spark.udf.register("ageAvg", myAvgUdaf)
        
        df.createOrReplaceTempView("user")
        
        spark.sql("select ageAvg(age) from user").show()
        
        // 关闭环境
        spark.close()
    }
    
}

// 自定义聚合函数类,实现求年龄平均值
class MyAvgUdaf extends UserDefinedAggregateFunction {
    
    // 输入数据的结构类型
    def inputSchema: StructType = {
        // StructType 是样例类
        StructType(Array(
        	// StructField 是样例类,必传参数 name: String, dataType: DataType
            StructField("age", LongType)
        ))
    } 
    
    // 缓冲区的结构类型
    def bufferSchema: StructType = {
        StructType(Array(
        	StructField("totalAge", LongType),
            StructField("count", LongType)
        ))
    }
    
    // 输出数据的结构类型
    def dataType: DataType = DoubleType
    
    // 函数稳定性
    def deterministic: Boolean = true
    
    // 缓冲区初始化
    def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer.update(0, 0L)
        buffer.update(1, 0L)
    }
    
    // 接收输入数据更新缓冲区数据
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        val totalAge = buffer.getLong(0)
        val count = buffer.getLong(1)
        val age = input.getLong(0)
        
        buffer.update(0, totalAge + age)
        buffer.update(1, count + 1)
    }
    
    // 合并缓冲区
    def merge(buffer1: MutableAggregationBuffer,buffer2: Row): Unit = {
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
        buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }
    
    // 计算最终结果
    def evaluate(buffer: Row): Any = {
        buffer.getLong(0).toDouble/buffer.getLong(1)
    }
    
}

2. 强类型函数

2.1 Spark3.0 之前
/**
	自定义步骤:
		1.继承 Aggregator 抽象类,定义泛型
			IN:输入数据类型
			BUF:缓冲区类型
			OUT:输出数据类型
		2.重写 6 个方法
*/
object TestSparkSqlUdaf1 {
    def main(args: Array[String]): Unit = {
        // 创建 sparksql 环境对象
        val conf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
        val spark = SparkSession.builder().config(conf).getOrCreate()
        
        // 引入环境对象中的隐式转换
        import spark.implicits._
        
        val df: DataFrame = spark.read.json("data/user.json")
        /*
        	需求:自定义求年龄平均值的udaf函数
        */
        // Spark3.0 之前的强类型UDAF函数必须在 DSL 语法中使用
        val ds = df.as[User]
        
        // 将UDAF函数对象转换成 DSL 语法中的查询列
        val col: TypedColumn[User, Double] = new MyAvgUdaf().toColumn
        
        ds.select(col).show()
        
        // 关闭环境
        spark.close()
    }
    
}

// 定义封装输入的一行数据的类
case class User(username: String, age: Long)

// 定义缓冲区类
case class Buff(var totalAge: Long, var count: Long)

// 自定义聚合函数类,实现求年龄平均值
class MyAvgUdaf extends Aggregator[User, Buff, Long] {
    // 缓冲区初始化
    override def zero: Buff = Buff(0L, 0L)
    
    // 根据输入数据更新缓冲区数据
    override def reduce(buff: Buff, in: User): Buff = {
        buff.totalAge = buff.totalAge + in.age
        buff.count = buff.count + 1
        buff
    }
    
    // 合并缓冲区
    override def merge(buff1: Buff, buff2: Buff): Buff = {
        buff1.totalAge = buff1.totalAge + buff2.totalAge
        buff1.count = buff1.count + buff2.count
        buff1
    }
    
    // 计算最终结果
    override def finish(buff: Buff): Double = {
        buff.totalAge.toDouble/buff.count
    } 
    
    //DataSet 默认的编解码器,用于序列化,固定写法
    //自定义类型是 product 
    // 缓冲区编码操作
    override def bufferEncoder: Encoder[Buff] = Encoders.product
    
    // 输出数据编码操作
    // 自带类型根据类型选择
    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    
}
2.2 Spark3.0 之后
/**
	自定义步骤:
		1.继承 Aggregator 抽象类,定义泛型
			IN:输入数据类型
			BUF:缓冲区类型
			OUT:输出数据类型
		2.重写 6 个方法
*/
object TestSparkSqlUdaf1 {
    def main(args: Array[String]): Unit = {
        // 创建 sparksql 环境对象
        val conf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
        val spark = SparkSession.builder().config(conf).getOrCreate()
        
        // 引入环境对象中的隐式转换
        import spark.implicits._
        
        val df: DataFrame = spark.read.json("data/user.json")
        /*
        	需求:自定义求年龄平均值的udaf函数
        */
        // Spark3.0 之后的强类型UDAF可以在 SQL 语法中使用
        val myAvgUdaf = new MyAvgUdaf()
        
        // 注册函数时需要使用 functions.udaf(func) 包装转换
        spark.udf.register("ageAvg", functions.udaf(myAvgUdaf))
        
        df.createOrReplaceTempView("user")
        
        spark.sql("select ageAvg(age) from user").show()
        
        // 关闭环境
        spark.close()
    }
    
}

// 定义缓冲区类
case class Buff(var totalAge: Long, var count: Long)

// 自定义聚合函数类,实现求年龄平均值
class MyAvgUdaf extends Aggregator[Long, Buff, Long] {
    // 缓冲区初始化
    override def zero: Buff = Buff(0L, 0L)
    
    // 根据输入数据更新缓冲区数据
    override def reduce(buff: Buff, in: Long): Buff = {
        buff.totalAge = buff.totalAge + in
        buff.count = buff.count + 1
        buff
    }
    
    // 合并缓冲区
    override def merge(buff1: Buff, buff2: Buff): Buff = {
        buff1.totalAge = buff1.totalAge + buff2.totalAge
        buff1.count = buff1.count + buff2.count
        buff1
    }
    
    // 计算最终结果
    override def finish(buff: Buff): Double = {
        buff.totalAge.toDouble/buff.count
    } 
    
    //DataSet 默认的编解码器,用于序列化,固定写法
    //自定义类型是 product 
    // 缓冲区编码操作
    override def bufferEncoder: Encoder[Buff] = Encoders.product
    
    // 输出数据编码操作
    // 自带类型根据类型选择
    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    
}
  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值