1.UDF
注:以下的SparkSQL初始化方式不是最新的,请参考上篇博客进行修改
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
object UDFTest {
  /**
   * Registers two simple Hive UDFs and runs sample queries against them.
   *
   * NOTE: the functions handed to `udf.register` MUST carry an explicit
   * parameter type (`(name: String) => ...`). The original untyped form
   * (`name => ...`) cannot be compiled because the overloaded `register`
   * gives the compiler no expected type to infer the parameter from.
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("udf").setMaster("local")
    val sc = new SparkContext(conf)
    val hiveSQLContext = new HiveContext(sc)
    // Upper-cases a name; a single space stands in for NULL input
    // (preserves the original placeholder behavior).
    hiveSQLContext.udf.register("toUpper", (name: String) => {
      if (name != null) name.toUpperCase else " "
    })
    // Length of a name; 0 for NULL input.
    hiveSQLContext.udf.register("strLength", (name: String) => {
      if (name != null) name.length else 0
    })
    // Without an action the DataFrame returned by sql() is never
    // evaluated, so call show() to actually run and display the queries.
    hiveSQLContext.sql("select toUpper(name) from student").show()
    hiveSQLContext.sql("select strLength(name) from student").show()
    sc.stop()
  }
}
2.UDAF
(1) 1.6.0 版本
package lesson02
import org.apache.spark.sql.{Row, types}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types._
/**
*
* 需求:
* 需要求这家公司的所有员工的平均工资
* 思路:
* 1)先求出所有员工的工资 countSalary
* 2)求出员工的总数 count
* 3) 平均工资countSalary / count
*/
/**
 * UDAF computing the average salary across all workers.
 *
 * Strategy: accumulate (total salary, row count) in the aggregation
 * buffer, then divide at the end.
 *
 * NOTE: buffer cells must be written with the assignment syntax
 * `buffer(i) = v` (sugar for `buffer.update(i, v)`); the original
 * `buffer(i, v)` form is a compile error — `Row` has no two-argument
 * `apply`.
 */
object UDAFTest extends UserDefinedAggregateFunction {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("udf").setMaster("local")
    val sc = new SparkContext(conf)
    val hiveSQLContext = new HiveContext(sc)
    hiveSQLContext.udf.register("avg_salary", UDAFTest)
    // show() forces execution; sql() alone builds the plan lazily.
    hiveSQLContext.sql("select avg_salary(salary) from worker").show()
    sc.stop()
  }

  // Input: a single nullable Double column (the salary).
  override def inputSchema: StructType = StructType(
    StructField("salary", DoubleType, true) :: Nil
  )

  // Output: the average as a Double.
  override def dataType: DataType = DoubleType

  /*
   * Intermediate buffer used while aggregating:
   *   countSalary — running total of all salaries seen so far
   *   count       — running number of (non-null) rows seen so far
   */
  override def bufferSchema: StructType = StructType(
    StructField("countSalary", DoubleType, true) ::
    StructField("count", IntegerType, true) :: Nil
  )

  // Set the buffer's starting values: zero total, zero rows.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0
  }

  /**
   * Fold one input row into the buffer.
   * Null salaries are skipped so getDouble never sees a null cell.
   *
   * @param buffer running (total, count) so far
   * @param input  the incoming row
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getInt(1) + 1
    }
  }

  // Combine two partial buffers (e.g. from different partitions) into buffer1.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }

  // Produce the final average; guard the empty group so we return 0.0
  // instead of NaN when no rows were aggregated.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getInt(1)
    if (count == 0) 0.0 else buffer.getDouble(0) / count
  }

  // Same input always yields the same output.
  override def deterministic: Boolean = true
}
(2) 2.2.0 版本
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
/**
 * Untyped UDAF (Spark 2.x API) computing the arithmetic mean of a
 * Long column. Keeps a running (sum, count) pair in the aggregation
 * buffer and divides on finalization.
 */
object MyAverage extends UserDefinedAggregateFunction {
  /** One LongType input column. */
  def inputSchema: StructType =
    StructType(StructField("inputColumn", LongType) :: Nil)

  /** Buffer layout: slot 0 = running sum, slot 1 = running row count. */
  def bufferSchema: StructType =
    StructType(
      StructField("sum", LongType) ::
      StructField("count", LongType) :: Nil
    )

  /** The result is returned as a Double. */
  def dataType: DataType = DoubleType

  /** Identical input always produces identical output. */
  def deterministic: Boolean = true

  /** Start from an empty aggregate: zero sum, zero rows. */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  /** Fold one input row into the buffer; null cells are skipped. */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }

  /** Combine two partial aggregates, accumulating into buffer1. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /** Final answer: sum divided by count. */
  def evaluate(buffer: Row): Double = {
    val total = buffer.getLong(0)
    val rows = buffer.getLong(1)
    total.toDouble / rows
  }
}
// Register the function to access it
spark.udf.register("myAverage", MyAverage)
val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+