自定义函数:
内置函数无法满足需求时,通过自定义方式扩展函数库增强功能。
单行函数
函数应用于每一行数据,返回处理结果。
如:upper(小写转大写),lower(大写转小写)
object FunctionExample{
def main(args: Array [String]) : Unit= {
val spark = SparkSession.builder().appName("function example").master("local[*]").getOrCreate()
import spark.implicits._
val df = spark
.sparkContext
.makeRDD(List((1, "zs", true), (2, "ls", false), (3, "ww", true)))
.toDF("id","name","sex")
spark
.udf
.register("sex_convert",(sex:Boolean) => {
sex match{
case true => "男性"
case false => "女性"
}
})
df
.createOrReplaceTempView("t_user")
spark
.sql("select id,name,sex_convert(sex) from t_user")
.show()
spark.stop()
}
}
多行函数
对多行数据进行操作返回单行结果。
如:max(salary),count(),avg(xxx),avg等
object FunctionsExample2 {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().appName("function example").master("local[*]").getOrCreate()
import spark.implicits._
val df = spark
.sparkContext
.makeRDD(List((1, "zs", true, 1000.0), (2, "ls", false, 2000.0), (3, "ww", true, 3000.0)))
.toDF("id", "name", "sex", "salary")
// 自定义多行函数
spark
// user-defined functions
.udf
.register("my_sum", new UserDefinedAggregateFunction {
/**
* 输入的字段类型信息 my_sum(salary,age,xxx)
*
* @return
*/
override def inputSchema: StructType = {
new StructType().add("salary", DoubleType)
}
/**
* 中间结果类型
*
* @return
*/
override def bufferSchema: StructType = new StructType().add("total", DoubleType)
/**
* 数据类型 最终计算结果的返回类型 double + double + double = double
*
* @return
*/
override def dataType: DataType = DoubleType
/**
* 表示多行函数的输入类型是否和计算结果输出类型一致
*
* @return
*/
override def deterministic: Boolean = true
/**
* 初始化缓冲区方法
*
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
// 缓冲区中的第一个元素的初始值为0
buffer.update(0, 0.0)
}
/**
* 更新方法
*
* @param buffer 计算结果的缓冲区
* @param input 行对象
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val current = input.getDouble(0) // 多行函数参数构成的Row
val history = buffer.getDouble(0)
buffer.update(0, current + history) // 更新最新的聚合结果
}
/**
* 合并方法:多个缓冲区的结果合并到一起 buffer2的结果合并到buffer1中
*
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
val b1 = buffer1.getDouble(0)
val b2 = buffer2.getDouble(0)
buffer1.update(0, b1 + b2)
}
/**
* 评估方法:返回多行函数计算结果方法
*
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = buffer.getDouble(0)
})
df
.createOrReplaceTempView("t_user")
// 使用
spark
.sql("select my_sum(salary) from t_user")
.show()
spark.stop()
}
}