spark自定义UDAF函数

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}


object UserDefineFunction3 extends App {
  private val session: SparkSession = SparkSession.builder().appName("test").master("local").getOrCreate()
  //定义一个List
  val list = List(
    Score(1, "zs", 99.9),
    Score(2, "ls", 89.2),
    Score(3, "xm", 29.3),
    Score(4, "wx", 79.0),
    Score(5, "lq", 69.9),
    Score(6, "zl", 49.5)
  )

  //此处是session是SparkSession对象
  import session.implicits._
  //list转DataSet对象
  val ds = list.toDS()
  //创建临时视图
  ds.createTempView("tmp")
  //注册自定义函数
  session.udf.register("myavg", new MyAvgUDAF)
  //定义sql语句
  val sql =
    """
      |select avg(score),
      |myavg(score),
      |name
      |from tmp
      |group by name
      |""".stripMargin
 //执行sql
  session.sql(sql).show()
  session.stop()

}

case class Score(id: Int, name: String, score: Double)
//自定义UDAF类,继承UserDefinedAggregateFunction类
class MyAvgUDAF extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = {
    StructType(Array(StructField("score", DataTypes.DoubleType)))
  }

  //临时变量的数据类型
  override def bufferSchema: StructType = {
    StructType(Array(StructField("sum", DataTypes.DoubleType),
      StructField("count", DataTypes.IntegerType)))
  }

  override def dataType: DataType = DataTypes.DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //临时变量是存储在buffer中的,我们定义的buffer有两个元素,第一个是sum,第二个是count
    //设置第一个元素的初始值.0d 代表double类型
    buffer.update(0, 0d)
    //设置第二个元素的初始值 count的初始值为0
    buffer.update(1, 0)

  }

  //分区类累加操作
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val score: Double = input.getAs(0)
    //将新进来的行中的分数累加到第一个元素sum上
    buffer.update(0, buffer.getDouble(0) + score)
    //将第二个元素count累加1
    buffer.update(1, buffer.getInt(1) + 1)

  }

  //分区间累加操作
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1))
  }

  //计算输出结果
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0) / buffer.getInt(1)
  }
}
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值