title:sparksql自定义函数
一、sparksql自定义函数
spark-sql与hive类似,可以自定义函数
- UDF(user-defined-function):最基本的自定义函数,类似to_char,to_date.
- UDAF(user-defined-aggregation-function):用户自定义聚合函数,类似在group by之后使用的sum,avg等
- UDTF(user-defined-Table-Generating-function):用户自定义生成函数,有点像stream里面的flatMap
二、自定义函数编程
2.1UDF编程
一对一。在上下文对象的下面编写
格式:
spark.udf.register("自定义函数名",(字段名)=>{ //字段名:表示把那一列的数据输入进来
函数体
})
案例:把性别用0,1表示
//自定义函数 ==》UDF
spark.udf.register("sexToInt",(sex:String)=>{
//采用什么方式,将M和F 变成 1 或者 0
sex match {
case "M" => 1
case "F" => 0
case _ => -1
}
})
2.2UDAF编程
用户自定义聚合函数,编写一个类似avg()函数的自定义函数。
UDAF自定义函数需要继承UserDefinedAggregateFunction类,并实现其中的8个方法
2.2.1主类代码
使用之前要先注册:spark.udf.register(“UDAFavg”,UDAFavg)
spark.udf.register("UDAFavg",UDAFavg)
2.2.2自定义代码
package Day.Day6
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructField, StructType}
//继承uadf
//scala中extends即可继承也可实现
object UDAFavg extends UserDefinedAggregateFunction{
/**
* 分析需要定义函数的计算逻辑和返回值类型
* 平均值怎么算?
* 总和 /总数
* @return UDAFavg(sal,double)
*/
/**
* 1.输入数据的数据类型
* @return
*/
override def inputSchema: StructType = {
StructType(Array(
StructField("v1",DoubleType)
))
}
/**
* 2.缓存字段的数据类型
* @return
*/
override def bufferSchema: StructType = {
StructType(Array(StructField("b1",DoubleType),
StructField("b2",IntegerType)))
}
/**
* 3.最终的函数的返回值类型
* @return
*/
override def dataType: DataType = DoubleType
/**
* 4.输入类型与输出类型是否一致
* @return
*/
override def deterministic: Boolean = true
/**
* 5.初始化缓存区的值
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0,0.0)
buffer.update(1,0)
}
/**
* 6.缓存区数据输入开始更新(数据的更新是每个分区中都会进行的)
* @param buffer 缓存区内数据
* @param input 输入的数据
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//先取出数据,然后计算
//1.读取输入的数据
val in = input.getDouble(0)
//2.取出缓存区的数据
val b1 = buffer.getDouble(0)
val b2 = buffer.getInt(1)
//3.更新数据
buffer.update(0,in+b1)
buffer.update(1,b2+1)
}
/**
* 7.合并,在全局范围内,将每个分区中得到的 临时的总和和总数再次聚合
* @param buffer1 分区1缓存的数据
* @param buffer2 分区2缓存的数据
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//1.获取输入数据
val b3 = buffer2.getDouble(0)
val b4 = buffer2.getInt(1)
//2.获取缓存区的数据
val b1 = buffer1.getDouble(0)
val b2 = buffer1.getInt(1)
//3.更新数据
buffer1.update(0,b1+b3)
buffer1.update(1,b2+b4)
}
/**
* 8.计算平均值
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = {
buffer.getDouble(0) / buffer.getInt(1)
}
}