首先用Scala写一个UDAF函数
import scala.collection.mutable.{ArrayBuffer, WrappedArray}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
/**
 * Spark SQL user-defined aggregate function that returns the median of a
 * Double column.
 *
 * All non-null values of a group are accumulated in an array buffer and sorted
 * once in [[evaluate]], so the values of an entire group are held in memory at
 * the same time. NULL inputs are ignored; a group without any non-null value
 * yields NULL.
 */
class UDAFMedian extends UserDefinedAggregateFunction {

  /** Input schema: a single Double column named `value`. */
  def inputSchema: StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  /** Buffer schema: one array of Doubles (no null elements) holding the values seen so far. */
  def bufferSchema: StructType = StructType(
    StructField("data_list", ArrayType(DoubleType, false)) :: Nil
  )

  /** Result type of the aggregate. */
  def dataType: DataType = DoubleType

  /** The same input always produces the same output. */
  def deterministic: Boolean = true

  /** Starts every aggregation with an empty value list. */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[Double]
  }

  /**
   * Appends one input value to the buffer.
   *
   * NULL inputs are skipped: reading them through `getAs[Double]` would
   * silently unbox to 0.0 and distort the median.
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getSeq[Double](0) :+ input.getDouble(0)
    }
  }

  /** Merges two partial buffers by concatenating their value lists into `buffer1`. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0)
  }

  /**
   * Computes the final result from the collected values.
   *
   * @return the middle element for an odd count, the mean of the two middle
   *         elements for an even count, or `null` (SQL NULL) when no value
   *         was collected
   */
  def evaluate(buffer: org.apache.spark.sql.Row): Any = {
    // Copy into an array so the index lookups below are O(1).
    val sortedValues = buffer.getSeq[Double](0).toArray.sorted
    val count = sortedValues.length
    if (count == 0) {
      // Empty group (e.g. only NULL inputs, or a global aggregate over an
      // empty table): indexing would throw, so return SQL NULL instead.
      null
    } else if (count % 2 == 0) {
      val upper = count / 2
      (sortedValues(upper) + sortedValues(upper - 1)) / 2
    } else {
      sortedValues(count / 2)
    }
  }
}
其次,注册该UDAF并使用
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ListBuffer
/**
 * Demo entry point: registers `UDAFMedian` under the SQL name `median` and
 * runs a query that computes the median of `score` per `class` from the
 * `scores` table.
 */
object TestMedian {

  /**
   * Runs the demo query, formats each result row as a comma-separated line
   * and prints it.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val ss: SparkSession = SparkSession.builder().master("local").enableHiveSupport().getOrCreate()
    try {
      // Register the custom UDAF under the name "median".
      ss.udf.register("median", new UDAFMedian())

      // Use the median function from SQL.
      val sql = "select class, median(score) from scores group by class"
      val rows = ss.sql(sql).collect()

      // One comma-separated line per result row; NULL cells become empty strings.
      val result: Array[String] = rows.map { row =>
        row.toSeq.map(value => Option(value).fold("")(_.toString)).mkString(",")
      }

      // Emit the formatted lines so the computed result is not discarded.
      result.foreach(println)
    } finally {
      // Always release the session, even when the query fails.
      ss.stop()
    }
  }
}
官方UDAF示例参考: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html