The code covered in this article has been uploaded to https://github.com/xtxxtxxtx/commerce
Spark's DataFrame API provides generic aggregation methods such as count(), countDistinct(), avg(), max() and min(). These functions are designed for DataFrames; Spark SQL also offers type-safe variants, with both Java and Scala interfaces, for use with strongly typed Datasets.
This article covers the two interfaces Spark provides for user-defined aggregate functions:
- UserDefinedAggregateFunction
- Aggregator
Between them, these two interfaces cover most custom aggregation needs.
UserDefinedAggregateFunction
A UDAF implemented by extending UserDefinedAggregateFunction is weakly typed.
The class's source (comments translated):
abstract class UserDefinedAggregateFunction extends Serializable {

  // A StructType describing the types of this aggregate function's input
  // parameters. For example, a UDAF implementation that takes two inputs,
  // of types DoubleType and LongType, would use:
  //   new StructType()
  //     .add("doubleInput", DoubleType)
  //     .add("longInput", LongType)
  // The UDAF will only accept input data matching this schema.
  def inputSchema: StructType

  // A StructType describing the types of the values in the aggregation
  // buffer. For example, a UDAF whose buffer holds two values, of types
  // DoubleType and LongType, would use:
  //   new StructType()
  //     .add("doubleInput", DoubleType)
  //     .add("longInput", LongType)
  // Likewise, only buffer data matching this schema is accepted.
  def bufferSchema: StructType

  // The return type of this UDAF.
  def dataType: DataType

  // Returns true if this function is deterministic, i.e. given the same
  // input it always produces the same output.
  def deterministic: Boolean

  // Initializes the aggregation buffer, e.g. with zero values.
  // The contract is that merging two initial buffers must return the
  // initial buffer itself, i.e. merge(initialBuffer, initialBuffer)
  // should equal initialBuffer.
  def initialize(buffer: MutableAggregationBuffer): Unit

  // Updates the given aggregation buffer with new input data.
  // Called once per input row.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  // Merges two aggregation buffers, storing the updated values back
  // into buffer1. Called when combining two partially aggregated datasets.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  // Computes the final result of this UDAF for the given buffer.
  def evaluate(buffer: Row): Any

  // Creates a Column for this UDAF using the given Columns as inputs.
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }

  // Creates a Column for this UDAF using the distinct values of the
  // given Columns as inputs.
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}
/**
* A `Row` representing a mutable aggregation buffer.
*
* This is not meant to be extended outside of Spark.
*
* @since 1.5.0
*/
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {
/** Update the ith value of this buffer. */
def update(i: Int, value: Any): Unit
}
Example implementation:
class GroupConcatDistinct extends UserDefinedAggregateFunction {
  // Input type: a single String column
  override def inputSchema: StructType = StructType(StructField("cityInfo", StringType) :: Nil)
  // Buffer type: a single String
  override def bufferSchema: StructType = StructType(StructField("bufferCityInfo", StringType) :: Nil)
  // Output type: String
  override def dataType: DataType = StringType
  // Deterministic: the same input always yields the same output
  override def deterministic: Boolean = true

  /**
   * Initialize the intermediate aggregation buffer
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }

  // Update the buffer with input data, similar to combineByKey
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val cityInfo = input.getString(0)
    var bufferCityInfo = buffer.getString(0)
    if (!bufferCityInfo.contains(cityInfo)) {
      if ("".equals(bufferCityInfo)) {
        bufferCityInfo += cityInfo
      } else {
        bufferCityInfo += "," + cityInfo
      }
      buffer.update(0, bufferCityInfo)
    }
  }

  /**
   * Merge buffer2 into buffer1; called when combining the partial
   * aggregates of two partitions, similar to reduceByKey
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var bufferCityInfo1 = buffer1.getString(0)
    // bufferCityInfo2: cityId1:cityName1,cityId2:cityName2
    val bufferCityInfo2 = buffer2.getString(0)
    for (cityInfo <- bufferCityInfo2.split(",")) {
      if (!bufferCityInfo1.contains(cityInfo)) {
        if ("".equals(bufferCityInfo1)) {
          bufferCityInfo1 += cityInfo
        } else {
          bufferCityInfo1 += "," + cityInfo
        }
      }
    }
    buffer1.update(0, bufferCityInfo1)
  }

  // Compute and return the final aggregation result
  override def evaluate(buffer: Row): Any = {
    buffer.getString(0)
  }
}
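Because the update and merge bodies above are plain string manipulation, they can be sanity-checked without a Spark session. A minimal sketch (the `distinctConcat` helper is ours, not part of Spark; it mirrors the update/merge logic, including its substring-based `contains` check):

```scala
// Mirrors GroupConcatDistinct.update/merge: append a value to the
// comma-separated buffer only if the buffer does not already contain it.
def distinctConcat(buffer: String, value: String): String =
  if (buffer.contains(value)) buffer
  else if (buffer.isEmpty) value
  else buffer + "," + value

// "update" applied to one partition's rows, duplicates dropped
val partition1 = Seq("1:beijing", "2:shanghai", "1:beijing")
  .foldLeft("")(distinctConcat)

// "merge" of a second partial buffer into the first
val partition2 = "2:shanghai,3:shenzhen"
val merged = partition2.split(",").foldLeft(partition1)(distinctConcat)

println(partition1) // 1:beijing,2:shanghai
println(merged)     // 1:beijing,2:shanghai,3:shenzhen
```

In Spark itself the UDAF is registered once and then called from SQL, e.g. `spark.udf.register("group_concat_distinct", new GroupConcatDistinct)` followed by something like `SELECT area, group_concat_distinct(city_info) FROM t GROUP BY area` (table and column names here are hypothetical).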
Aggregator
A strongly typed custom aggregate function is implemented by extending Aggregator.
The class's source (comments translated):
// For example:
//   case class Data(i: Int)
//
//   val customSummer = new Aggregator[Data, Int, Int] {
//     def zero: Int = 0
//     def reduce(b: Int, a: Data): Int = b + a.i
//     def merge(b1: Int, b2: Int): Int = b1 + b2
//     def finish(r: Int): Int = r
//   }.toColumn()
//
//   val ds: Dataset[Data] = ...
//   val aggregated = ds.select(customSummer)
//
// @tparam IN  The input type for the aggregation.
// @tparam BUF The type of the intermediate value of the reduction.
// @tparam OUT The type of the final output result.
// @since 1.6.0
@Experimental
@InterfaceStability.Evolving
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {

  // The zero value for this aggregation; must satisfy b + zero = b
  // for any input b.
  def zero: BUF

  // Combine two values to produce a new value. For performance, the
  // function may modify b and return it instead of constructing a
  // new object for b.
  def reduce(b: BUF, a: IN): BUF

  // Merge two intermediate values.
  def merge(b1: BUF, b2: BUF): BUF

  // Transform the output of the reduction.
  def finish(reduction: BUF): OUT

  // The Encoder for the intermediate value type.
  def bufferEncoder: Encoder[BUF]

  // The Encoder for the final output value type.
  def outputEncoder: Encoder[OUT]

  // Return this aggregator as a TypedColumn so it can be used with Datasets.
  def toColumn: TypedColumn[IN, OUT] = {
    implicit val bEncoder = bufferEncoder
    implicit val cEncoder = outputEncoder
    val expr =
      AggregateExpression(
        TypedAggregateExpression(this),
        Complete,
        isDistinct = false)
    new TypedColumn[IN, OUT](expr, encoderFor[OUT])
  }
}
Example implementation:
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

class MyAverage extends Aggregator[Employee, Average, Double] {
  // The zero value of the aggregation
  override def zero: Average = Average(0L, 0L)

  // Update the buffer with one input value
  override def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }

  // Merge two buffers, folding b2 into b1
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // Compute the final output
  override def finish(reduction: Average): Double = {
    reduction.sum.toDouble / reduction.count
  }

  /**
   * Encoder for the intermediate value type, a case class here;
   * Encoders.product is the encoder for Scala tuples and case classes
   */
  override def bufferEncoder: Encoder[Average] = Encoders.product

  // Encoder for the final output value type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
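Since zero, reduce, merge and finish are pure functions over the buffer, the fold logic can be exercised directly, with no SparkSession. A minimal Spark-free sketch (local `Emp`/`Avg` case classes stand in for `Employee`/`Average` so the snippet compiles on its own):

```scala
case class Emp(name: String, salary: Long)
case class Avg(var sum: Long, var count: Long)

// Same shape as MyAverage.zero / reduce / merge / finish
def zero: Avg = Avg(0L, 0L)
def reduce(b: Avg, e: Emp): Avg = { b.sum += e.salary; b.count += 1; b }
def merge(b1: Avg, b2: Avg): Avg = { b1.sum += b2.sum; b1.count += b2.count; b1 }
def finish(b: Avg): Double = b.sum.toDouble / b.count

// Two "partitions" reduced independently, then merged -- the same path
// Spark takes for a typed Aggregator.
val p1 = Seq(Emp("a", 3000L), Emp("b", 4000L)).foldLeft(zero)(reduce)
val p2 = Seq(Emp("c", 5000L)).foldLeft(zero)(reduce)
val total = merge(p1, p2)
println(finish(total)) // 4000.0
```

With Spark on the classpath the real class is used via toColumn, e.g. `val avgSalary = new MyAverage().toColumn.name("avg_salary")` and then `ds.select(avgSalary)` on a `Dataset[Employee]`.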