sparkSql自定义聚合函数

 自定义聚合函数分两个类型,一个是强类型的,需要用DSL语句,另一个就是下面这种

import java.lang
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util.Properties

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.{JdbcRDD, RDD}
import org.apache.spark.util.{AccumulatorV2, LongAccumulator}
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
object test11 {
    def main(args: Array[String]): Unit = {

        val spark: SparkSession = SparkSession.builder().master("local[*]").appName("haha").getOrCreate()
        import spark.implicits._
        val ssc: SparkContext = spark.sparkContext
        // 生成DF
        val userDF: DataFrame = spark.read.json("input/user.json")
        // 生成视图
        userDF.createOrReplaceTempView("user")
        // 生成自定义聚合函数
        val udaf = new MyAvg
        // 注册自定义聚合函数
        spark.udf.register("ageAvg",udaf)
        spark.sql("select ageAvg(age) from user").show()
        spark.stop()
    }
}
class MyAvg extends UserDefinedAggregateFunction {
    // 表示聚合函数的输入的数据结构
    override def inputSchema: StructType = {
        StructType(Array(StructField("age",LongType)))
    }
    // 表示聚合运算时缓冲区的数据结构
    override def bufferSchema: StructType = {
        StructType(Array(StructField("total",LongType),(StructField("count",LongType))))
    }
    // 聚合函数的结果类型
    override def dataType: DataType = {
        DoubleType
    }
    // 稳定性
    override def deterministic: Boolean = true
    // 聚合函数的初始化(缓冲区的初始化)
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L
        buffer(1) = 0L
    }
    // 缓冲区内更新
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getLong(0) + input.getLong(0)
        buffer(1) = buffer.getLong(1) + 1L
    }
    // 缓冲区之间的聚合
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }
    // 计算聚合函数的结果
    override def evaluate(buffer: Row): Any = {
        buffer.getLong(0).toDouble / buffer.getLong(1)
    }
}

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值