Spark SQL——UDAF

本文所涉及到的代码已上传到https://github.com/xtxxtxxtx/commerce

Spark 的DataFrame提供了通用的聚合方法,比如count()、countDistinct()、avg()、max()、min()等等,但是这些函数是针对DateFrame设计的,当然Spark SQL也有类型安全的版本,Java和Scala语言接口都有,这些适用于强类型的DataSet。

本文主要讲解一下Spark提供的两种聚合函数接口:

  1. UserDefinedAggregateFunction
  2. Aggregator

这两个接口基本满足了用户自定义函数的需求。

UserDefinedAggregateFunction

通过继承UserDefinedAggregateFunction来实现用户自定义聚合函数属于弱类型UDAF函数。

该类的源码:

abstract class UserDefinedAggregateFunction extends Serializable {
StructType代表的是该聚合函数输入参数的类型。例如,一个UDAF实现需要两个输入参数,
类型分别是DoubleType和LongType,那么该StructType格式如下:
    new StructType()
    .add("doubleInput",DoubleType)
    .add("LongType",LongType)
那么该udaf就只会识别,这种类型的输入的数据。
  def inputSchema: StructType
   该StructType代表aggregation buffer的类型参数。例如,一个udaf的buffer有
   两个值,类型分别是DoubleType和LongType,那么其格式将会如下:
     new StructType()
      .add("doubleInput", DoubleType)
      .add("longInput", LongType)
     也只会适用于类型格式如上的数据
   
  def bufferSchema: StructType

    dataTypeda代表该UDAF的返回值类型
  def dataType: DataType

    如果该函数是确定性的,那么将会返回true,例如,给相同的输入,就会有相同
    的输出
  def deterministic: Boolean
    
    初始化聚合buffer,例如,给聚合buffer以0值
    在两个初始buffer调用聚合函数,其返回值应该是初始函数自身,例如
    merge(initialBuffer,initialBuffer)应该等于initialBuffer。
  def initialize(buffer: MutableAggregationBuffer): Unit

    利用输入输入去更新给定的聚合buffer,每个输入行都会调用一次该函数
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

    合并两个聚合buffer,并且将更新的buffer返回给buffer1
    该函数在聚合并两个部分聚合数据集的时候调用
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

    计算该udaf在给定聚合buffer上的最终结果
  def evaluate(buffer: Row): Any


    使用给定的Column作为输入参数,来为当前UDAF创建一个Column
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
Column(aggregateExpression)
  }

    使用给定Column去重后的值作为参数来生成一个Column
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
Column(aggregateExpression)
  }
}

/**
 * A `Row` representing a mutable aggregation buffer.
 *
 * This is not meant to be extended outside of Spark.
 *
 * @since 1.5.0
 */
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {

/** Update the ith value of this buffer. */
  def update(i: Int, value: Any): Unit
}

案例实现:

class GroupConcatDistinct extends UserDefinedAggregateFunction{

  //UDAF:输入类型是String
  override def inputSchema: StructType = StructType(StructField("cityInfo", StringType) :: Nil)

  //缓冲区类型
  override def bufferSchema: StructType = StructType(StructField("bufferCityInfo", StringType) :: Nil)

  //输出类型是String
  override def dataType: DataType = StringType

  // 一致性检验,如果为true那么输入不变的情况下结果也是不变的
  override def deterministic: Boolean = true

  /**
    * 设置聚合中间buffer的中间值
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }

  //用输入数据input更新buffer值,类似于combineByKey
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val cityInfo = input.getString(0)
    var bufferCityInfo = buffer.getString(0)

    if (!bufferCityInfo.contains(cityInfo)){
      if ("".equals(bufferCityInfo)){
        bufferCityInfo += cityInfo
      }else{
        bufferCityInfo += "," + cityInfo
      }
      buffer.update(0, bufferCityInfo)
    }
  }

  /**
    * 合并这两个buffer,将buffer2合并到buffer1,在合并两个分区聚合结果时候会用到,类似于reduceByKey
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var bufferCityInfo1 = buffer1.getString(0)
    // bufferCityInfo2: cityId1:cityName1, cityId2:cityName2
    val bufferCityInfo2 = buffer2.getString(0)

    for (cityInfo <- bufferCityInfo2.split(",")){
      if (!bufferCityInfo1.contains(cityInfo)){
        if ("".equals(bufferCityInfo1)){
          bufferCityInfo1 += cityInfo
        }else{
          bufferCityInfo1 += "," + cityInfo
        }
      }
    }
    buffer1.update(0, bufferCityInfo1)
  }

  //计算并返回最终的聚合结果
  override def evaluate(buffer: Row): Any = {
    buffer.getString(0)
  }
}

Aggregator

通过继承Aggregator来实现强类型自定义聚合函数。

该类的源码:

      举个栗子
 *   val customSummer =  new Aggregator[Data, Int, Int] {
 *     def zero: Int = 0
 *     def reduce(b: Int, a: Data): Int = b + a.i
 *     def merge(b1: Int, b2: Int): Int = b1 + b2
 *     def finish(r: Int): Int = r
 *   }.toColumn()
 *
 *   val ds: Dataset[Data] = ...
 *   val aggregated = ds.select(customSummer)
 * }}}
 * @tparam IN The input type for the aggregation.
 * @tparam BUF The type of the intermediate value of the reduction.
 * @tparam OUT The type of the final output result.
 * @since 1.6.0
 */
@Experimental
@InterfaceStability.Evolving
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {

    该剧和函数的0值。需要满足对于任何输入b,那么b+zero=b
  def zero: BUF

    聚合两个值产生一个新的值,为了提升性能,该函数会修改b,然后直接返回b,而
    不适新生成一个b的对象。

  def reduce(b: BUF, a: IN): BUF
    合并两个中间值
  def merge(b1: BUF, b2: BUF): BUF

    转换reduce的输出
  def finish(reduction: BUF): OUT

    为中间值类型提供一个编码器
  def bufferEncoder: Encoder[BUF]

    为最终的输出结果提供一个编码器 
  def outputEncoder: Encoder[OUT]
    
    将该聚合函数返回为一个TypedColumn,目的是为了能在Dataset中使用
  def toColumn: TypedColumn[IN, OUT] = {
implicit val bEncoder = bufferEncoder
implicit val cEncoder = outputEncoder

val expr =
AggregateExpression(
TypedAggregateExpression(this),
        Complete,
        isDistinct = false)

new TypedColumn[IN, OUT](expr, encoderFor[OUT])
  }
}

案例实现:

case class Employee(name : String, salary : Long)
case class Average(var sum : Long, var count : Long)

class MyAverage extends Aggregator[Employee, Average, Double]{
  
  //计算并返回最终的聚合结果
  override def zero: Average = Average(0L, 0L)
  
  //根据传入的参数值更新buffer值
  override def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }
  
  //合并两个buffer值,将buffer2合并到buffer1
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  
  //计算输出
  override def finish(reduction: Average): Double = {
    reduction.sum.toDouble / reduction.count
  }

  /**
    * 设定中间值类型的编码器需要转换成case类
    * Encoders.product是将scala元组和case类转换的编码器
    */
  override def bufferEncoder: Encoder[Average] = Encoders.product
  
  //设定最终输出值的编码器
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值