// Spark SQL: custom UDF and UDAF functions
package com.spark.sparksql
import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructType}
import org.junit.Test
/**
 * JUnit-driven demonstrations of registering custom functions with Spark SQL:
 * a scalar UDF (`preId`) and two average aggregates (`myAvg`, `myAvg2`).
 *
 * Each test builds a tiny in-memory DataFrame and prints the query result with `show()`.
 */
class $03_udf {

  /** Local session with 4 worker threads; obtained via `getOrCreate()`. */
  val spark = SparkSession
    .builder()
    .master("local[4]")
    .appName("test")
    .getOrCreate()

  // Brings `toDF` (and the required encoders) into scope for local collections.
  import spark.implicits._

  /** Three-row fixture with columns `id`, `name`, `age` used by every test. */
  private def sampleUsers =
    List(
      ("1001", "zhangsan", 20),
      ("000102", "lisi", 30),
      ("0123", "wangwu", 40)
    ).toDF("id", "name", "age")

  /** Registers [[prefixUserId]] as SQL function `preId` and applies it to the `id` column. */
  @Test
  def udf: Unit = {
    val users = sampleUsers
    spark.udf.register("preId", prefixUserId _)
    users.selectExpr("preId(id)", "name", "age").show()
  }

  /**
   * Left-pads an id with `'0'` characters up to a width of 8.
   *
   * @param id the raw identifier
   * @return `id` prefixed with zeros so that its length is 8; ids that are already
   *         8 characters or longer are returned unchanged
   */
  def prefixUserId(id: String): String = {
    val missing = 8 - id.length
    if (missing > 0) ("0" * missing) + id
    else id
  }

  /**
   * Runs the same average over `age` twice: once through the untyped `MyAvgAgg`
   * (registered directly) and once through the typed `MyAvgAgg2` wrapped with `udaf`.
   */
  @Test
  def udaf1: Unit = {
    import org.apache.spark.sql.functions.udaf

    val users = sampleUsers

    // Untyped UserDefinedAggregateFunction variant.
    spark.udf.register("myAvg", new MyAvgAgg)
    users.selectExpr("myAvg(age)").show()

    // Typed Aggregator variant, adapted for SQL use via `udaf`.
    spark.udf.register("myAvg2", udaf(new MyAvgAgg2))
    users.selectExpr("myAvg2(age)").show()
  }
}
/**
 * Accumulator for an average: a running total and the number of values that went into it.
 * Used as the intermediate buffer type of `MyAvgAgg2`.
 *
 * @param sum   running total of the aggregated values
 * @param count how many values have been added to `sum`
 */
case class Buff(
  var sum: Int,
  var count: Int
)
/**
 * Typed aggregator that averages `Int` inputs.
 *
 * Input type is `Int`, the intermediate state is a [[Buff]] (sum + count) and the
 * result is a `Double`. The running sum is an `Int`, so it wraps around if the
 * total exceeds the 32-bit range.
 */
class MyAvgAgg2 extends Aggregator[Int, Buff, Double] {

  /** Neutral starting state: nothing summed, nothing counted. */
  override def zero: Buff = Buff(sum = 0, count = 0)

  /** Folds one input value into the state, returning a new buffer. */
  override def reduce(b: Buff, a: Int): Buff =
    b.copy(sum = b.sum + a, count = b.count + 1)

  /** Combines two partial states field by field. */
  override def merge(b1: Buff, b2: Buff): Buff =
    Buff(sum = b1.sum + b2.sum, count = b1.count + b2.count)

  /**
   * Turns the final state into the average.
   * For an untouched buffer (sum 0, count 0) the division is `0.0 / 0`, i.e. `NaN`.
   */
  override def finish(reduction: Buff): Double = {
    val total = reduction.sum.toDouble
    total / reduction.count
  }

  /** Encoder for the intermediate [[Buff]] state. */
  override def bufferEncoder: Encoder[Buff] = Encoders.product[Buff]

  /** Encoder for the `Double` result. */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
/**
 * Untyped user-defined aggregate function that computes the arithmetic mean of an
 * integer column.
 *
 * The aggregation buffer has two fields: `sum` (index 0) and `count` (index 1).
 * Both are stored as 64-bit longs so that summing many large `Int` values does not
 * wrap around. SQL NULL inputs are skipped and contribute to neither field.
 * If no non-null value was aggregated the result is `NaN` (`0.0 / 0`).
 *
 * NOTE(review): Spark 3.x documents `UserDefinedAggregateFunction` as deprecated in
 * favour of `Aggregator` registered through `functions.udaf` - confirm against the
 * Spark version in use before extending this class.
 */
class MyAvgAgg extends UserDefinedAggregateFunction {
  // Local import: only the internal buffer layout needs LongType.
  import org.apache.spark.sql.types.LongType

  /** One input column of integer type. */
  override def inputSchema: StructType = new StructType().add("input", IntegerType)

  /** Buffer layout: running `sum` and `count`, both as longs. */
  override def bufferSchema: StructType =
    new StructType()
      .add("sum", LongType)
      .add("count", LongType)

  /** The aggregate yields a double. */
  override def dataType: DataType = DoubleType

  /** Same input always produces the same output. */
  override def deterministic: Boolean = true

  /** Resets the buffer to sum = 0, count = 0. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }

  /**
   * Folds one input row into the buffer.
   *
   * @param buffer running (sum, count) state, mutated in place
   * @param input  row holding the single input column
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // A NULL cell would otherwise unbox to 0 and still bump the count,
    // silently dragging the average down.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getInt(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  /**
   * Adds the partial state `buffer2` into `buffer1`.
   *
   * @param buffer1 destination state, mutated in place
   * @param buffer2 partial state to merge in
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /**
   * Produces the final average from the buffer.
   *
   * @return `sum / count` as a `Double`; `NaN` when `count` is 0
   */
  override def evaluate(buffer: Row): Any = {
    val total = buffer.getLong(0).toDouble
    total / buffer.getLong(1)
  }
}