UDAF概述
UDAF定义
class 类名 extends UserDefinedAggregateFunction{ }
UDAF的使用
val 对象名= new 类名
spark. udf. register( "自定义UDAF名称" , 对象名)
val df2: DataFrame = spark. sql( "select 字段,自定义UDAF名称(字段) from userinfo group by 字段" )
UDAF示例
/**
 * Demo entry point: registers [[MyAgeAvgFunction]] as the SQL function `myAvgAge`
 * and uses it to print the average `age` per `sex` for the records in `in/user.json`.
 */
object SparkUDAFDemo {

  /**
   * Runs the demo on a local Spark session.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("UDAF").getOrCreate()
    try {
      val users: DataFrame = spark.read.json("in/user.json")

      // Expose the aggregate to Spark SQL under the name used in the query below.
      spark.udf.register("myAvgAge", new MyAgeAvgFunction)

      // getOrCreate() may hand back an already-running session in which the view
      // exists; createOrReplaceTempView avoids the "view already exists" failure
      // that createTempView would raise in that case.
      users.createOrReplaceTempView("userinfo")

      val avgAgeBySex: DataFrame =
        spark.sql("select sex,myAvgAge(age) from userinfo group by sex")
      avgAgeBySex.show()
    } finally {
      // Always release the session's resources, even when reading or querying fails.
      spark.stop()
    }
  }
}
/**
 * User-defined aggregate function that computes the arithmetic mean of a `Long` column.
 *
 * The aggregation buffer holds two `Long` slots: slot 0 is the running sum and slot 1 is
 * the number of non-null values seen so far. Null inputs are ignored, and a group with no
 * non-null inputs yields SQL NULL.
 *
 * Note: `UserDefinedAggregateFunction` is deprecated since Spark 3.0 in favour of
 * `Aggregator` registered through `functions.udaf`; the base class is kept here so
 * existing callers that instantiate and register this class keep working.
 */
class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  /** Schema of the input arguments: a single `Long` column named `age`. */
  override def inputSchema: StructType =
    new StructType().add(StructField("age", LongType))

  /** Schema of the intermediate buffer: running `sum` and running `count`, both `Long`. */
  override def bufferSchema: StructType =
    new StructType()
      .add(StructField("sum", LongType))
      .add(StructField("count", LongType))

  /** Type of the final result. */
  override def dataType: DataType = DoubleType

  /** The same set of inputs always produces the same result. */
  override def deterministic: Boolean = true

  /** Resets the buffer to an empty aggregation (sum = 0, count = 0). */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  /**
   * Folds one input row into the buffer.
   *
   * @param buffer aggregation buffer to update in place
   * @param input  row holding the single input value at index 0
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Row.getLong throws on a null cell, so rows with a missing age are skipped
    // instead of crashing the whole aggregation.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  /**
   * Combines two partial aggregations by adding the second buffer's sum and count
   * into the first.
   *
   * @param buffer1 buffer that receives the merged result
   * @param buffer2 partial aggregation to merge in
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /**
   * Produces the final average from a fully merged buffer.
   *
   * @param buffer merged aggregation buffer
   * @return the mean as a `Double`, or `null` (SQL NULL) when no non-null value was aggregated
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    // Spark represents SQL NULL as null at this API boundary; returning it avoids
    // emitting NaN from a 0 / 0 division for groups without any usable value.
    if (count == 0L) null
    else buffer.getLong(0).toDouble / count
  }
}