Spark_SparkSQL: Defining a UDAF in Spark SQL

References

1. Spark 2.4.0 Programming Guide -- Spark SQL UDF and UDAF
https://yq.aliyun.com/articles/680259

2. spark-sql custom UDF and UDAF functions
https://www.cnblogs.com/zzq-include/p/8758961.html

3. Spark SQL UDAF and UDF explained (repost)
https://blog.csdn.net/fengfengchen95/article/details/88681780

 

Based on Spark 2.2.0

In Spark SQL we routinely use the built-in aggregate functions such as max and min. This article shows how to define and use a custom aggregate function (UDAF) in Spark SQL.

 

Basic definition

To implement a custom UDAF, extend UserDefinedAggregateFunction.

The methods that must be overridden, and the meaning of their parameters, are shown below:

package com.spark.test.offline.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StructType}

/**
  * Created by szh on 2020/5/31.
  * @author szh
  */
class TestSum extends UserDefinedAggregateFunction{

  // Data type(s) of the aggregate function's input arguments
  override def inputSchema: StructType = ???

  // Data types of the intermediate aggregation buffer
  override def bufferSchema: StructType = ???

  // Data type of the final result
  override def dataType: DataType = ???

  // Whether the function always returns the same result for the same input
  override def deterministic: Boolean = ???

  // Initial values of the aggregation buffer; if the Dataset has no data,
  // the result is derived from these values
  override def initialize(buffer: MutableAggregationBuffer): Unit = ???

  /**
    * Called once for every input row in a partition; the running result is accumulated in the buffer.
    * @param buffer the aggregation buffer for the current partition (updated in place)
    * @param input  the current input row
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???

  /**
    * Merges the aggregation buffers of two partitions.
    * @param buffer1 the buffer of one partition (updated in place)
    * @param buffer2 the buffer of another partition
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???

  // Computes the final result from the buffer
  override def evaluate(buffer: Row): Any = ???

}
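Once these methods are implemented, Spark drives the aggregation itself: initialize sets up each aggregation buffer, update is called once per input row within a partition, merge combines the partial buffers from different partitions, and evaluate produces the final value per group. A minimal registration-and-use sketch (assuming the ??? placeholders above have been filled in with a concrete implementation, and that a table t with a numeric column v already exists):

// hypothetical names: table `t`, column `v`
spark.udf.register("test_sum", new TestSum)        // register an instance under a SQL name
spark.sql("SELECT test_sum(v) FROM t").show()      // invoke it like any built-in aggregate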

 

Three examples

We implement three UDAFs below:

customer_sum

customer_avg

customer_max

 

customer_sum

package com.spark.test.offline.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

/**
  * Created by szh on 2020/5/31.
  */
object CustomerSum extends UserDefinedAggregateFunction {

  // Data type of the input argument
  def inputSchema: StructType = {
    StructType(StructField("inputColumn", LongType) :: Nil)
  }

  // Data type of the aggregation buffer; only a running sum is needed here
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: Nil)
  }

  // Data type of the final result
  def dataType: DataType = LongType

  // The function is deterministic: the same input always produces the same result
  def deterministic: Boolean = true

  // Initial buffer value; if the Dataset has no data, the result is 0
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
  }

  /**
    * Called once for every input row in a partition; adds the row's value to the running sum.
    * @param buffer the aggregation buffer (updated in place)
    * @param input  the current input row
    */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
    }
  }

  /**
    * Merges the buffers of two partitions.
    * @param buffer1 the buffer of one partition (updated in place)
    * @param buffer2 the buffer of another partition
    */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  }

  // Computes the final result
  def evaluate(buffer: Row): Long = buffer.getLong(0)
}
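Besides SQL registration (shown later in the driver code), a UserDefinedAggregateFunction can also be applied directly through the DataFrame API via its apply method. A minimal sketch, assuming a DataFrame df with user_id and score columns (the variable and column names here are illustrative):

import org.apache.spark.sql.functions.col

// cast score to Long explicitly to match the declared LongType input
df.groupBy("user_id")
  .agg(CustomerSum(col("score").cast("long")).as("total_score"))
  .show()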

 

customer_avg

package com.spark.test.offline.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * Created by szh on 2020/5/31.
  */
object CustomerAvg extends UserDefinedAggregateFunction {

  // Data type of the input argument
  def inputSchema: StructType = {
    StructType(StructField("inputColumn", LongType) :: Nil)
  }

  // Data types of the aggregation buffer: a running sum and a row count
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }

  // Data type of the final result
  def dataType: DataType = DoubleType

  // The function is deterministic: the same input always produces the same result
  def deterministic: Boolean = true

  // Initial buffer values; if the Dataset has no data, sum and count are both 0
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  /**
    * Called once for every input row in a partition; accumulates the sum and the count.
    * @param buffer the aggregation buffer (updated in place)
    * @param input  the current input row
    */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0) // sum
      buffer(1) = buffer.getLong(1) + 1                // count
    }
  }

  /**
    * Merges the buffers of two partitions.
    * @param buffer1 the buffer of one partition (updated in place)
    * @param buffer2 the buffer of another partition
    */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) // sum
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) // count
  }

  // Computes the final result: sum / count (NaN if no non-null rows were seen)
  def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}

 

 

customer_max

package com.spark.test.offline.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * Created by szh on 2020/5/31.
  */
object CustomerMax extends UserDefinedAggregateFunction {

  // Data type of the input argument
  def inputSchema: StructType = {
    StructType(StructField("inputColumn", LongType) :: Nil)
  }

  // Data type of the aggregation buffer: the running maximum
  def bufferSchema: StructType = {
    StructType(StructField("max", LongType) :: Nil)
  }

  // Data type of the final result
  def dataType: DataType = LongType

  // The function is deterministic: the same input always produces the same result
  def deterministic: Boolean = true

  // Initial buffer value; starting at 0 assumes non-negative inputs
  // (use Long.MinValue instead if negative values are possible)
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
  }

  /**
    * Called once for every input row in a partition; keeps the larger of the buffer value and the row value.
    * @param buffer the aggregation buffer (updated in place)
    * @param input  the current input row
    */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      if (input.getLong(0) > buffer.getLong(0)) {
        buffer(0) = input.getLong(0)
      }
    }
  }

  /**
    * Merges the buffers of two partitions by keeping the larger value.
    * @param buffer1 the buffer of one partition (updated in place)
    * @param buffer2 the buffer of another partition
    */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer2.getLong(0) > buffer1.getLong(0)) buffer1(0) = buffer2.getLong(0)
  }

  // Computes the final result
  def evaluate(buffer: Row): Long = buffer.getLong(0)


}

 

Example driver code

package com.spark.test.offline.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/**
  * Created by szh on 2020/5/31.
  */
object SparkSQLUdaf {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf
    conf
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      // .set("spark.kryo.registrationRequired", "true")
      // Register classes with Kryo. `User` is assumed to be a case class defined elsewhere
      // in the project; this registration is not actually required for the UDAF examples below.
      .registerKryoClasses(
        Array(
          classOf[User]
          , classOf[scala.collection.mutable.WrappedArray.ofRef[_]]
        ))

    val spark = SparkSession
      .builder()
      .appName("sparkSql")
      .master("local[1]")
      .config(conf)
      .getOrCreate()

    val sc = spark.sparkContext
    sc.setLogLevel("ERROR")

    //user_id,project_id,score
    val rddA = sc.parallelize(Seq(
      (1, 1, 86.3)
      , (1, 2, 88.3)
      , (1, 3, 90.3)
      , (2, 1, 60.0)
    ))

    spark
      .createDataFrame(rddA)
      .toDF("user_id", "project_id", "score")
      .createTempView("user_info")

    spark.udf.register("com_sum", CustomerSum)
    val df = spark.sql("SELECT user_id, com_sum(score) FROM user_info GROUP BY user_id ")
    df.show()

    spark.udf.register("com_avg", CustomerAvg)
    val df2 = spark.sql("SELECT user_id, com_avg(score) FROM user_info GROUP BY user_id ")
    df2.show()

    spark.udf.register("com_max", CustomerMax)
    val df3 = spark.sql("SELECT com_max(score) FROM user_info ")
    df3.show()

    sc.stop()
    spark.stop()

  }

}

 

Output

+-------+-----------------------------------+
|user_id|customersum$(CAST(score AS BIGINT))|
+-------+-----------------------------------+
|      1|                                264|
|      2|                                 60|
+-------+-----------------------------------+

+-------+-----------------------------------+
|user_id|customeravg$(CAST(score AS BIGINT))|
+-------+-----------------------------------+
|      1|                               88.0|
|      2|                               60.0|
+-------+-----------------------------------+
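The generated column names such as customersum$(CAST(score AS BIGINT)) come from the UDAF object's class name plus the implicit cast of score (a Double column) to the LongType input the UDAF declares. If a cleaner column name is wanted, the call can simply be aliased in SQL, for example:

val df = spark.sql("SELECT user_id, com_sum(score) AS total_score FROM user_info GROUP BY user_id")
df.show()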

 
