用户自定义函数
–1.UDF
package com.qf.sql.day02
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
 * Example: registering an ordinary Scala method as a Spark SQL scalar UDF.
 *
 * The program loads a JSON file into a DataFrame, exposes it as the temporary
 * view `emp`, registers [[getLevel]] under the SQL name `getLevel`, and prints
 * the result of a query that applies the UDF to the `sal` column.
 */
object _06UserDefineFunction {

  /** Input file used when no path is passed on the command line. */
  private val DefaultEmpJsonPath: String =
    "file:///D:\\Users\\Michael\\Documents\\IdeaProjects\\sz2003\\sz2003_sparksql01/data/emp.json"

  /**
   * Entry point.
   *
   * @param args optional; when present, `args(0)` is the path of the JSON file
   *             to load instead of the built-in default path
   */
  def main(args: Array[String]): Unit = {
    val empJsonPath = args.headOption.getOrElse(DefaultEmpJsonPath)
    val spark = SparkSession.builder().master("local").appName("test").getOrCreate()
    // Stop the session even when reading the file or running the query fails,
    // so a failed run does not leave the local Spark context alive.
    try {
      val df: DataFrame = spark.read.json(empJsonPath)
      df.createTempView("emp")
      // `getLevel _` turns the method into a function value that Spark can wrap as a UDF.
      spark.udf.register("getLevel", getLevel _)
      val sql =
        """
          |select ename,
          |job,
          |sal,
          |getLevel(sal) as level
          |from
          |emp
          |""".stripMargin
      spark.sql(sql).show()
    } finally {
      spark.stop()
    }
  }

  /**
   * Maps a salary to a level label.
   *
   * @param sal salary value
   * @return `"level1"` when `sal` is above 3000, `"level2"` when it is above
   *         1500 (and at most 3000), otherwise `"level3"`
   */
  def getLevel(sal: Int): String =
    if (sal > 3000) "level1"
    else if (sal > 1500) "level2"
    else "level3"
}
–2.UDAF
package day02
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
/**
 * Row type of the sample data set built in `_03UDAF.main`.
 *
 * @param id    integer id column
 * @param name  name column; the demo query groups by it
 * @param score value column; the demo query averages it
 */
case class Score(
  id: Int,
  name: String,
  score: Double
)
/**
 * Example: registering a user-defined aggregate function (UDAF) and comparing
 * its output with Spark's built-in `avg`.
 */
object _03UDAF {

  /**
   * Builds a small in-memory Dataset of [[Score]] rows, exposes it as the
   * temporary view `tmp`, registers [[MyAvgUDAF]] under the SQL name `myavg`,
   * and prints `avg(score)` next to `myavg(score)` for every name.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("udaf").master("local").getOrCreate()
    // The session is released in `finally` so it is stopped on success and on failure.
    try {
      val list = List(
        Score(1, "张三", 99.9),
        Score(2, "李四", 88.9),
        Score(3, "小明", 77.9),
        Score(1, "张三", 91.9),
        Score(2, "李四", 81.9),
        Score(3, "小明", 71.9)
      )
      // Brings the encoders needed by `toDS()` into scope.
      import spark.implicits._
      val ds: Dataset[Score] = list.toDS()
      ds.createTempView("tmp")
      spark.udf.register("myavg", new MyAvgUDAF)
      val sql = "select avg(score), myavg(score),name from tmp group by name"
      spark.sql(sql).show()
    } finally {
      spark.stop()
    }
  }
}
/**
 * User-defined aggregate function that computes the arithmetic mean of a
 * double column.
 *
 * The aggregation buffer holds two slots: slot 0 is the running sum and
 * slot 1 is the number of non-null values seen so far. NULL inputs are
 * skipped, and an aggregate that saw no non-null value evaluates to `null`
 * rather than NaN.
 */
class MyAvgUDAF extends UserDefinedAggregateFunction {

  /** Schema of the function's input: a single double column. */
  override def inputSchema: StructType =
    StructType(Array(StructField("score", DataTypes.DoubleType)))

  /**
   * Schema of the intermediate buffer: running sum and value count.
   * The count is a Long so it cannot wrap around on very large groups.
   */
  override def bufferSchema: StructType =
    StructType(Array(
      StructField("sum", DataTypes.DoubleType),
      StructField("count", DataTypes.LongType)
    ))

  /** Type of the final result. */
  override def dataType: DataType = DataTypes.DoubleType

  /** The same input always produces the same output. */
  override def deterministic: Boolean = true

  /** Resets the buffer to sum = 0.0 and count = 0. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0d)
    buffer.update(1, 0L)
  }

  /**
   * Folds one input row into the buffer.
   *
   * @param buffer buffer to update in place
   * @param input  row whose column 0 is the value to add
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Skip NULLs: reading a null cell as Double would yield 0.0 and still
    // bump the count, silently lowering the average.
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
      buffer.update(1, buffer.getLong(1) + 1L)
    }
  }

  /**
   * Combines two partial buffers by adding their sums and counts.
   *
   * @param buffer1 buffer that receives the merged result
   * @param buffer2 buffer whose contents are added into `buffer1`
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  /**
   * Produces the final result from a buffer.
   *
   * @return sum divided by count, or `null` when no non-null value was aggregated
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null
    else buffer.getDouble(0) / count
  }
}