自定义UDAF函数实现去重统计效果

package com.erongda.bigdata.scala.spark.sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.collection.Seq._
import scala.Predef.Set

/**
  * 自定义UDAF函数 实现去重效果
  */
object DistinctUDAF extends UserDefinedAggregateFunction {

  import scala.collection.immutable.Set
  var set: Set[Int] = Set()
  /**
    * 输入的数据类型
    * @return
    */
  override def inputSchema: StructType = StructType(
    Array(StructField("sal",IntegerType,nullable = true))
  )

  /**
    * 表示的是聚合过程中 缓冲临时变量的数据类型,也是封装在StructType里面
    * @return
    */
  override def bufferSchema: StructType = StructType(
   Array(StructField("sal_dis",IntegerType, nullable = true))
    )

  /**
    * 输出的数据类型
    * @return
    */
  override def dataType: DataType = IntegerType

  /**
    *
    * @return
    */
  override def deterministic: Boolean = true

  /**
    * 初始化 对聚合中的 临时缓冲的值
    * @param buffer
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0
  }

  /**
    * TODO: 针对每个分区数据进行操作的
    *   表示的是针对每条数据进行聚合函数以后,对聚合缓冲临时变量的更新
    * @param buffer
    *               聚合缓冲临时变量
    * @param input
    *              输入变量
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    //获得缓冲临时的变量值
//    val saldis = buffer.getSeq(0).toSet

    //获取输入传递进来的值
//    val inputSal = input.getDouble(0)


//    val value_set =  (buffer.getSeq(0).toSet + input.getDouble(0)).toSeq
    //更新缓冲数据的值
//    buffer.update(0, value_set)
      buffer(0)=buffer.getInt(0)+input.getInt(0)
      set += (input.getInt(0))
  }


  /**
    * TODO: 针对合并所有分区的聚合结果
    *  从字面意思看,就是合并的意思
    *     This is called when we merge two partially aggregated data together.
    *
    *  表示的是:
    *      正对不同分区 聚合的结果  进行合并操作
    *
    * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
    *     表示将两个聚合缓冲的数据,合并以后并存储到 buffer1中
    *
    * @param buffer1
    * @param buffer2
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //分别获得缓冲临时变量的值
//    val saldis = buffer1.getSeq(0).toSet
//
//    val saldis2 = buffer2.getSeq(0).toSet
//
//    val total  = (saldis ++ saldis2).toSeq
//    //合并更新
//    buffer1.update(0, total)
    buffer1(0)=buffer1.getInt(0)+buffer2.getInt(0)
     set ++= set
  }

  /**
    * 表示的是一句 聚合缓冲的临时变量 计算 聚合最终结果
    * @param buffer
    * @return
    */
  override def evaluate(buffer: Row): Any = {
    //a. 获取聚合缓冲数据
//    val saldis = buffer.getSeq(0).length

    //b.计算并返回
    set.size
  }
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值