大数据系列篇-SPARK-SQL用户定义聚合函数

16 篇文章 0 订阅
9 篇文章 0 订阅

大数据系列篇-SPARK-SQL用户定义聚合函数

package com.test

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
import org.apache.spark.sql._

//用户定义聚合函数
object SparkSqlUdaf {

  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf().setAppName("练习SparkSqlUdaf").setMaster("local[*]")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate() //使用$转换时

    val df = spark.read.json("data/user.json")
    df.createOrReplaceTempView("user")

    //弱类型
    spark.udf.register("avgAge1", new UdafAvg1)
    spark.sql("SELECT avgAge1(age) as avgAge1 FROM user").show()

    //强类型
    spark.udf.register("avgAge2", functions.udaf(new UdafAvg2))
    spark.sql("SELECT avgAge2(age) as avgAge2 FROM user").show()

    //低版本中UDAF使用DSL的方式
    import spark.implicits._
    val ds = df.as[User]
    val column = new UdafAvg3().toColumn //udaf函数转换列查询
    ds.select(column).show()

    spark.close()
  }


  //(弱类型-按位置)
  class UdafAvg1 extends UserDefinedAggregateFunction {
    //输入
    override def inputSchema: StructType = {
      StructType(
        Array(
          StructField("age", LongType)
        )
      )
    }

    //缓冲区
    override def bufferSchema: StructType = {
      StructType(
        Array(
          StructField("total", LongType),
          StructField("count", LongType),
        )
      )
    }

    //输出
    override def dataType: DataType = LongType

    //函数稳定度
    override def deterministic: Boolean = true

    //缓冲区initialize  按位置
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 0L

      // 或
      // buffer.update(0,0L)
      // buffer.update(1,0L)
    }

    //缓冲区update 按位置
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      buffer.update(0, buffer.getLong(0) + input.getLong(0))
      buffer.update(1, buffer.getLong(1) + 1)
    }

    //缓冲区merge
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      buffer1.update(1, buffer1.getLong(1) + +buffer2.getLong(1))
    }

    // 计算
    override def evaluate(buffer: Row): Any = {
      buffer.getLong(0) / buffer.getLong(1)
    }
  }

  //强类型
  class UdafAvg2 extends Aggregator[Long, /*输入类型*/ UserBuf, /*缓冲类型*/ Long /*输出类型*/ ] {
    //缓冲区initialize
    override def zero: UserBuf = {
      UserBuf(0L, 0)
    }

    //按输入更新缓冲区
    override def reduce(b: UserBuf, a: Long): UserBuf = {
      b.count += 1
      b.total += a
      b
    }

    //缓冲区merge
    override def merge(b1: UserBuf, b2: UserBuf): UserBuf = {
      b1.count += b2.count
      b1.total += b2.total
      b1
    }

    //计算
    override def finish(reduction: UserBuf): Long = {
      reduction.total / reduction.count
    }

    //缓冲区Encoder
    override def bufferEncoder: Encoder[UserBuf] = Encoders.product //定义的类用Encoders.product

    //输出Encoder
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }

  case class UserBuf(var total: Long, var count: Long)

  //强类型
  class UdafAvg3 extends Aggregator[User, /*输入类型*/ UserBuf, /*缓冲类型*/ Long /*输出类型*/ ] {
    //缓冲区initialize
    override def zero: UserBuf = {
      UserBuf(0L, 0)
    }

    //按输入更新缓冲区
    override def reduce(b: UserBuf, a: User): UserBuf = {
      b.count += 1
      b.total += a.age
      b
    }

    //缓冲区merge
    override def merge(b1: UserBuf, b2: UserBuf): UserBuf = {
      b1.count += b2.count
      b1.total += b2.total
      b1
    }

    //计算
    override def finish(reduction: UserBuf): Long = {
      reduction.total / reduction.count
    }

    //缓冲区Encoder
    override def bufferEncoder: Encoder[UserBuf] = Encoders.product //定义的类用Encoders.product

    //输出Encoder
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }

  case class User(var userName: String, var age: Long)

}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值