spark SQL UDF和UDAF

UDF

  • UDF(User Define Function)
    spark内置的函数不能满足我们的要求的时候,我们通常需要自定义函数来实现我们的需求。
  • 示例
object UDF {
    def main(args: Array[String]): Unit = {
        val sparkSession = SparkSession.builder()
                .appName("UDF")
                .master("local[2]")
                .getOrCreate()
                
        // 创建一个RDD
        val names = Array("zhangsan", "lisi", "wangwu", "Tom", "Jerry", "Alan")
        val namesRDD = sparkSession.sparkContext.parallelize(names)
        
        // 转为DF
        // 方法一:动态加载
        val namesRowRDD = namesRDD.map(x => Row(x))
        val schema = StructType(Array(
            StructField("name", StringType, true)
        ))
        val namesDF = sparkSession.createDataFrame(namesRowRDD, schema)
        
        // 方法二:反射
        import sparkSession.implicits._
        val namesDF = namesRDD.toDF("name")
        
        namesDF.createOrReplaceTempView("udfTest")
        // 注册UDF函数
        sparkSession.udf.register("strLength", (str:String) =>str.length)
        sparkSession.sql("select name, strLength(name) length from udfTest").show()
    }
}
+--------+------+
|    name|length|
+--------+------+
|zhangsan|     8|
|    lisi|     4|
|  wangwu|     6|
|     Tom|     3|
|   Jerry|     5|
|    Alan|     4|
+--------+------+

UDAF

  • UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是:普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值按指定方法聚合一下。
  • UDAF的使用有两种方式:1.继承UserDefinedAggregateFunction 2.继承Aggregator(优点是可以带类型)

使用UDAF一般步骤:

  1. 自定义类继承UserDefinedAggregateFunction或者Aggregator,对每个阶段方法做实现
  2. 在sparkSession中注册UDAF,为其绑定一个名字
  3. 在sql语句中使用上面绑定的名字调用
  • 继承UserDefinedAggregateFunction
class UDAFStringCount extends UserDefinedAggregateFunction {
    // 输入数据的类型
    override def inputSchema: StructType = {
        StructType(Array(StructField("str", StringType, true)))
    }
    // 中间聚合时所处理的数据
    override def bufferSchema: StructType = {
        StructType(Array(StructField("count", IntegerType, true)))
    }
    // 函数返回的类型
    override def dataType: DataType = {
        IntegerType
    }
    // 指定是否是确定性的
    override def deterministic: Boolean = {
        true
    }

    // 为每个分组的数据执行初始化操作
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0
    }
    // 每个分组有新值过来,如何进行分组对应的聚合值的计算
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getAs[Int](0) + 1
    }
    // 合并,一个分组的数据会分布在多个节点上处理,所以最后要用merge
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
    }
    // 通过中间的缓存聚合值,最后返回一个最终的聚合值
    override def evaluate(buffer: Row): Any = {
        buffer.getAs[Int](0)
    }
}
  • 注册udaf,sql调用
object UDAF {
    def main(args: Array[String]): Unit = {
        //UDAF可以针对多行输入,进行聚合计算,返回一个输出
        val sparkSession = SparkSession.builder()
                .appName("UDAF")
                .master("local[2]")
                .getOrCreate()

        val names = Array("zhangsan", "lisi", "wangwu", "Tom", "Jerry", "zhangsan", "Tom", "zhangsan",
            "lisi", "wangwu", "Tom", "Jerry", "Alan")

        val namesRDD = sparkSession.sparkContext.parallelize(names)

        import sparkSession.implicits._
        val namesDF = namesRDD.toDF("name")

        namesDF.createOrReplaceTempView("udafTest")

        sparkSession.udf.register("strCount", new UDAFStringCount)

        sparkSession.sql("select name, strCount(name) len from udafTest group by name").show()
    }
}
+--------+---+
|    name|len|
+--------+---+
|  wangwu|  2|
|     Tom|  3|
|   Jerry|  2|
|zhangsan|  3|
|    Alan|  1|
|    lisi|  2|
+--------+---+
  • 继承Aggregator
在这里插入代码片
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值