Spark2.2(四)用户自定义聚合函数

用户自定义聚合函数

package doc.df

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}

/**
  * @Program: doc.df
  * @Author: huangwei
  * @Date: 2019/9/16 17:34
  * @description: 用户自定义聚合函数
  *              无类型的用户自定义函数(Untyped User-Defined Aggregate Functions)
  *              实现无类型用户自定义聚合函数需要继承抽象类UserDefinedAggregateFunction,并重写该类的8个函数
  */
object UserDefindUntypedAggregate {

  object MyAverage extends UserDefinedAggregateFunction{
    // 1、inputSchema 定义输入数据的Schema,要求类型是StructType,它的参数是由StructField类型构成的列表
    // 这里定义salary列的Schema,首先使用StructField声明salary列的名字salaryColumn,数据类型为Long,这里只输入salary这一列,所以StructField构成的列表只有一个元素
    // ::是Scala的操作符,与空集合Nil操作后生成一个列表
    override def inputSchema: StructType = StructType(StructField("salaryColumn",LongType)::Nil)

    // 2、bufferSchema 事实上需要计算salary平均值的时候,需要用到salary的总和sum和总个数count这样的中间数据,那么就使用bufferSchema来定义
    override def bufferSchema: StructType = StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)

    // 3、dataType 我们需要自定义聚合函数最终数据类型进行说明,使用dataType函数,这里salary的类型为Double类型
    override def dataType: DataType = DoubleType

    // 4、deterministic 用户对输入数据进行一致性检验,是一个布尔值,当为True时,表示对于同样的输入会得到同样的输出,因为对于同样的Salary输入,肯定要得到相同的Salary平均著,所以定义为true
    override def deterministic: Boolean = true

    // 5、initialize  用于初始化缓存数据,salary的缓存数据有两个:sum和count,需要初始化sum为0L,count为0L
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 0L
    }

    // 6、update  当有新的输入数据时,更新缓存变量,这里有新的salary输入时,需要更新sum值,并将count加1
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      // 输入非空
      if (!input.isNullAt(0)){
        buffer(0) = buffer.getLong(0) + input.getLong(0)  // sum = sum+输入的salary
        buffer(1) = buffer.getLong(1) + 1                 // count = count + 1
      }
    }

    // 7、merge  将更新的缓存变量存入到缓存中
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    // 8、evalute 用于计算最后的结果,这里用于计算平均值
    override def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("Spark SQL user-defined DataFrames aggregation example")
      .master("local")
      .getOrCreate()

    // 注册名为myAverage的自定义集成算子MyAverage
    spark.udf.register("myAverage", MyAverage)

    val df = spark.read.json("E:\\IdeaProjects\\SparkProject\\src\\main\\resources\\employess.json")
    df.createOrReplaceTempView("employee")
    df.show()
//    +-------+------+
//    |   name|salary|
//    +-------+------+
//    |Michael|  3000|
//    |   Andy|  4500|
//    | Justin|  3500|
//    |  Berta|  4000|
//    +-------+------+
    val avg_salary = spark.sql("SELECT myAverage(salary) as average_salary FROM employee")
    avg_salary.show()
//    +--------------+
//    |average_salary|
//    +--------------+
//    |        3750.0|
//    +--------------+

  }


}

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值