前言:
最然spark提供的算子和函数非常丰富,但是对于一些特殊的业务需求,还是自定义函数比较好用,自定义函数一般有UDF和UDAF,UDF就是一对一的函数,UDAF是一对多函数
一:自定义UDF函数
步骤:
1.自定义方法–满足需求
2.注册方法
3.在sql语句中使用函数
//需求:正常情况下员工的工号是8位,现在工号长度不够,需要使用0来补全
object demo03 {
private val spark = SparkSession.builder().master("local[6]").appName("aggregation").getOrCreate()
import spark.implicits._
//创建数据
private val rdd: RDD[(String, String)] = spark.sparkContext.parallelize(Seq( ("010","张三"),("020","李四"),( "003","王五")))
//把数据转成表的形式
rdd.toDF("id","name").createTempView("employee")
//自定义UDF函数,该函数的功能是补全员工id的位数,所以参数是id,返回值类型是String
def fillID(id:String): String ={
"0"*(8-id.length)+id
}
def main(args: Array[String]): Unit = {
//注册自定义的UDF函数,参数1:函数的名称 , 参数2:自定义的方法转成函数的形式
spark.udf.register("fillID", fillID _)
//验证自定义函数
spark.sql("select fillID(ID),name from employee").show()
}
}
结果:
+--------------+----+
|UDF:fillID(ID)|name|
+--------------+----+
| 00000010| 张三|
| 00000020| 李四|
| 00000003| 王五|
+--------------+----+
二:自定义UDAF函数
步骤:
1.创建类继承UserDefinedAggregateFunction类
2.重写方法
3.测试自定义函数
//需求:计算员工平均工资
object demo04 {
//1.创建类去继承
class UdafFunction extends UserDefinedAggregateFunction{
//2.重写8个方法
/**
* 聚合函数输入参数的类型--返回值是StructType类型
* @return
*/
override def inputSchema: StructType = new StructType().add("input",IntegerType)
/**
* 缓冲区:在计算过程中需要用到的中间变量的类型
* 需要用到两个中间变量: 一个为sum【用来统计价格的总和】 一个为total【用来记录商品的个数】
* @return
*/
override def bufferSchema: StructType = {
new StructType().add("sum",IntegerType).add("total",IntegerType)
}
/**
* 指明返回值类型
* @return
*/
override def dataType: DataType = DoubleType
/**
* 是否保存数据的一致性,一般设为true
* @return
*/
override def deterministic: Boolean = true
/**
* 初始化缓冲区,设置sum=0,total=0
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//初始化sum
buffer(0)=0
//初始化total
buffer(1)=0
}
/**
* 更新缓冲区的值
* 每进来一条数据,sum需要累加,total需要+1
* @param buffer
* @param input
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//sum进行累加
buffer(0) = buffer.getAs[Int](0)+input.getAs[Int](0)
//total自增1
buffer(1) = buffer.getAs[Int](1)+1
}
/**
* 合并缓冲区
* 将所有的缓冲区的sum与total值继续累加
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//累加sum
buffer1(0) = buffer1.getAs[Int](0)+buffer2.getAs[Int](0)
//累加total
buffer1(1) = buffer1.getAs[Int](1)+buffer2.getAs[Int](1)
}
/**
* 返回最终结果
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = {
buffer.getAs[Int](0).toDouble/buffer.getAs[Int](1)
}
}
def main(args: Array[String]): Unit = {
//创建入口
val spark = SparkSession.builder().master("local[6]").appName("aggregation").getOrCreate()
//导入隐式转换
import spark.implicits._
//创建数据
val rdd: RDD[(String, Int)] = spark.sparkContext.parallelize(Seq(("huawei",1000),("thinkpad",2000),("redmi",600)))
//创建临时表
rdd.toDF("name","price").createTempView("product")
//注册自定义UDAF函数
/**
* 源码
* def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction
* 参数1:名称,随便定义,用于使用函数的时候
* 参数2:UserDefinedAggregateFunction类型 , 是自定义的类的父类,所以new UdafFunction即可
*/
spark.udf.register("UdafFunction",new UdafFunction)
//使用自定义UDAF函数
spark.sql("select UdafFunction(price) from product").show()
}
}
结果:
+-------------------+
|udaffunction(price)|
+-------------------+
| 1200.0|
+-------------------+