SparkSql-自定义函数

1.UDF

现有数据的字段包括username和age,要求查询时在username的结果前加上字符串name:,如name:张三
  代码如下:

def main(args: Array[String]): Unit = {
	
	//创建上下文环境配置对象
	val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSql")
	//创建 SparkSession 对象
	val spark = SparkSession.builder().config(sparkConf).getOrCreate()
	
	val df: DataFrame = spark.read.json("datas/user.json")
	//创建临时表
	df.createOrReplaceTempView("user")
	//注册udf
	spark.udf.register("prefix",(username:String) => "name:"+username)
	//应用udf
	spark.sql("select prefix(username),age from user").show
	
	spark.stop()
}

2.UDAF

  强类型的 Dataset 和弱类型的 DataFrame 都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数。通过继承 UserDefinedAggregateFunction 来实现用户自定义弱类型聚合函数。从 Spark3.0 版本后,UserDefinedAggregateFunction 已经不推荐使用了。可以统一采用强类型聚合函数Aggregator。
需求:计算平均值的聚合函数

2.1 UDAF-弱类型

  • 1.自定义聚合类
class MyAveragUDAF extends UserDefinedAggregateFunction {
    //输入数据的结构
    override def inputSchema: StructType = {
        StructType(
            Array(
                StructField("age",LongType)
            )
        )
    }
    //聚合缓冲区数据的结构
    override def bufferSchema: StructType = {
        StructType(
            Array(
                StructField("sum",LongType),
                StructField("count",LongType)
            )
        )
    }
    //输出数据的结构
    override def dataType: DataType = LongType
    //稳定性
    override def deterministic: Boolean = true
    //初始化缓冲区的数据
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer.update(0,0L)
        buffer.update(1,0L)
    }
    //遍历传入数据的时候更新缓冲区的方法
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer.update(0,buffer.getLong(0) + input.getLong(0))
        buffer.update(1,buffer.getLong(1) + 1)
    }
    //缓冲区数据合并的数据,更新buffer1的数据
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1.update(0,buffer1.getLong(0) + buffer2.getLong(0))
        buffer1.update(1,buffer1.getLong(1) + buffer2.getLong(1))
    }
    //计算返回最后的结果
    override def evaluate(buffer: Row): Any = {
        buffer.getLong(0)/buffer.getLong(1)
    }
}
  • 2.使用该聚合函数
def main(args: Array[String]): Unit = {
    //创建上下文环境配置对象
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSql")
    //创建 SparkSession 对象
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    //注册自定义类
    spark.udf.register("ageAvg",new MyAveragUDAF())
    //创建DataFrame
    val df: DataFrame = spark.read.json("datas/user.json")
    //创建临时表
    df.createOrReplaceTempView("user")
    //使用自定义函数
    spark.sql("select ageAvg(age) from user").show

    spark.stop()
}
  • 3.结果展示
原始数据:
+---+--------+
|age|username|
+---+--------+
| 18|    张三|
| 19|    李四|
| 20|    王五|
+---+--------+
查询结果:
+-----------------+
|myaveragudaf(age)|
+-----------------+
|               19|
+-----------------+

2.1 UDAF-强类型

弱类型需要通过数据的顺序通过下标索引的方式操作数据,容易出错,强类型可以通过类属性的方式访问数据。

  • 1.自定义聚合类
/*
   *1.继承package org.apache.spark.sql.expressions.Aggregator,定义范性
       IN:输入数据类型
       BUF:缓冲区数据类型
       OUT:输出数据类型
   *2.重写6个方法
*/

case class Buff(var sum:Long,var count:Long)
class MyAveragUDAF extends Aggregator[Long,Buff,Long] {
   //缓冲区初始化
   override def zero: Buff = {
       Buff(0L,0L)
   }
   //根据输入的数据更新缓冲区
   override def reduce(buff: Buff, in: Long): Buff = {
       buff.sum = buff.sum + in
       buff.count = buff.count + 1
       buff
   }
   //合并数据
   override def merge(buff1: Buff, buff2: Buff): Buff = {
       buff1.sum = buff1.sum + buff2.sum
       buff1.count = buff1.count + buff2.count
       buff1
   }
   //计算结果
   override def finish(reduction: Buff): Long = {
       reduction.sum / reduction.count
   }
   //编码,自定义类和Scala自带类固定写法
   override def bufferEncoder: Encoder[Buff] = Encoders.product
   //解码,自定义类和Scala自带类固定写法
   override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
  • 2.使用聚合类
def main(args: Array[String]): Unit = {
    //创建上下文环境配置对象
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSql")
    //创建 SparkSession 对象
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    //注册自定义类
    spark.udf.register("ageAvg", functions.udaf(new MyAveragUDAF()))
    //创建DataFrame
    val df: DataFrame = spark.read.json("datas/user.json")
    //创建临时表
    df.createOrReplaceTempView("user")
    //使用自定义函数
    spark.sql("select ageAvg(age) from user").show

    spark.stop()
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值