一、用户可以通过spark.udf功能添加自定义函数,实现自定义功能
小需求案例:在名字前面加Name前缀,如Name:zhangxiaoming。
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
 * Demo: register a scalar UDF named "prefixName" that prepends "Name:" to a
 * user name, then call it from a SQL query over datas/user.json.
 */
object Spark02_SparkSQL_UDF {
  def main(args: Array[String]): Unit = {
    // lowerCamelCase so the val does not shadow the SparkConf class name.
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    try {
      val df = spark.read.json("datas/user.json")
      df.createOrReplaceTempView("user")
      // Register the UDF so SQL text can call it by name.
      spark.udf.register("prefixName", (name: String) => "Name:" + name)
      spark.sql("select age,prefixName(username) from user").show()
    } finally {
      // Release the SparkSession even if the query above fails.
      spark.stop()
    }
  }
}
执行结果:
二、创建UDAF(聚合)函数
需求二:计算平均年龄
方案一:采用RDD模式
import org.apache.spark.{SparkConf, SparkContext}
/**
 * Plan 1: compute the average age with plain RDD operations.
 * Each record becomes (age, 1); a single reduce produces (total, count).
 */
object Spark00_SparkSQL_UDF {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName("app").setMaster("local[*]")
    val sc: SparkContext = new SparkContext(conf)
    val people = List(("zhangsan", 20), ("lisi", 30), ("wangw", 40))
    // Destructure the reduced pair directly into total and count.
    val (totalAge, personCount) = sc
      .makeRDD(people)
      .map { case (_, age) => (age, 1) }
      .reduce((left, right) => (left._1 + right._1, left._2 + right._2))
    // Integer division, same as the original example.
    println(totalAge / personCount)
    // TODO shut down the environment
    sc.stop()
  }
}
运行结果:
方案二:采用累加器
import org.apache.spark.util.AccumulatorV2
/**
 * Plan 2: compute the average age with a custom accumulator.
 * MyAC keeps a running sum and count; the driver reads the average via value.
 */
object Spark03_SparkSQL_AccumulatorV2 {
  def main(args: Array[String]): Unit = {
    /**
     * Accumulator over Int inputs whose value is the integer average of
     * everything added so far.
     */
    class MyAC extends AccumulatorV2[Int, Int] {
      var sum: Int = 0
      var count: Int = 0

      // Zero state: nothing has been accumulated yet (no `return` keyword —
      // the last expression is the result).
      override def isZero: Boolean = sum == 0 && count == 0

      // Spark copies accumulators per task; state must be duplicated.
      override def copy(): AccumulatorV2[Int, Int] = {
        val newMyAc = new MyAC
        newMyAc.sum = this.sum
        newMyAc.count = this.count
        newMyAc
      }

      override def reset(): Unit = {
        sum = 0
        count = 0
      }

      // Fold one value in: add to the sum and bump the count.
      override def add(v: Int): Unit = {
        sum += v
        count += 1
      }

      // Combine partial results from different tasks.
      override def merge(other: AccumulatorV2[Int, Int]): Unit = other match {
        case o: MyAC =>
          sum += o.sum
          count += o.count
        case _ => // ignore accumulators of any other concrete type
      }

      // Guard against division by zero when nothing was ever added.
      override def value: Int = if (count == 0) 0 else sum / count
    }
  }
}
方案三:采用UDAF
3.1弱类型:
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
/**
 * Plan 3.1: weak-typed (untyped) UDAF computing the average age.
 * Removed the accidental, unused parquet `Operators.UserDefined` import.
 */
object Spark03_SparkSQL_UDAF {
  def main(args: Array[String]): Unit = {
    // TODO create the SparkSQL environment.
    // lowerCamelCase so the val does not shadow the SparkConf class name.
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    try {
      val df = spark.read.json("datas/user.json")
      df.createOrReplaceTempView("user")
      // Register the aggregate function so SQL can call it by name.
      spark.udf.register("ageAvg", new MyAvgUDAF())
      spark.sql("select ageAvg(age) from user").show()
    } finally {
      // TODO shut the environment down even when the query fails.
      spark.stop()
    }
  }

  /**
   * Custom aggregate function: average of the age column.
   *   1. extend UserDefinedAggregateFunction
   *   2. override its 8 abstract members
   */
  class MyAvgUDAF extends UserDefinedAggregateFunction {
    /** Input structure: one Long column, the age. */
    override def inputSchema: StructType =
      StructType(
        Array(
          StructField("age", LongType)
        )
      )

    /** Buffer structure: running total and row count. */
    override def bufferSchema: StructType =
      StructType(
        Array(
          StructField("total", LongType),
          StructField("count", LongType)
        )
      )

    /** Result data type of the function (integer average). */
    override def dataType: DataType = LongType

    /** Stable: the same input always produces the same output. */
    override def deterministic: Boolean = true

    /** Buffer initialisation. buffer(0)=0L is an equivalent spelling. */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0L)
      buffer.update(1, 0L)
    }

    /** Update the buffer from one input row; skip null ages to avoid an NPE. */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer.update(0, buffer.getLong(0) + input.getLong(0))
        buffer.update(1, buffer.getLong(1) + 1)
      }
    }

    /** Merge two partial buffers (combine across partitions). */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }

    /** Final average; guard against division by zero on empty input. */
    override def evaluate(buffer: Row): Any = {
      val count = buffer.getLong(1)
      if (count == 0L) 0L else buffer.getLong(0) / count
    }
  }
}
运行结果:
3.2强类型
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
/**
 * Plan 3.2: strongly-typed UDAF computing the average age, adapted to SQL
 * via functions.udaf.
 */
object Spark03_SparkSQL_UDAF1 {
  def main(args: Array[String]): Unit = {
    // TODO create the SparkSQL environment.
    // lowerCamelCase so the val does not shadow the SparkConf class name.
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    try {
      val df = spark.read.json("datas/user.json")
      df.createOrReplaceTempView("user")
      // functions.udaf adapts the typed Aggregator for use in SQL text.
      spark.udf.register("ageAvg", functions.udaf(new MyAvgUDAF()))
      spark.sql("select ageAvg(age) from user").show()
    } finally {
      // TODO shut the environment down even when the query fails.
      spark.stop()
    }
  }

  /** Aggregation buffer: running total and element count. */
  case class Buff(var total: Long, var count: Long)

  /**
   * Custom aggregate function: average of the age column.
   *   1. extend org.apache.spark.sql.expressions.Aggregator with type params
   *      IN  = Long (input), BUF = Buff (buffer), OUT = Long (output)
   *   2. override its 6 abstract members
   */
  class MyAvgUDAF extends Aggregator[Long, Buff, Long] {
    /** z / zero: the initial (empty) buffer. */
    override def zero: Buff = Buff(0L, 0L)

    /** Fold one input value into the buffer. */
    override def reduce(buff: Buff, in: Long): Buff = {
      buff.total = buff.total + in
      buff.count = buff.count + 1
      buff
    }

    /** Merge two partial buffers (combine across partitions). */
    override def merge(buff1: Buff, buff2: Buff): Buff = {
      buff1.total = buff1.total + buff2.total
      buff1.count = buff1.count + buff2.count
      buff1
    }

    /** Final result; guard against division by zero on an empty Dataset. */
    override def finish(buff: Buff): Long =
      if (buff.count == 0L) 0L else buff.total / buff.count

    /** Encoder for the buffer type (a product/case class). */
    override def bufferEncoder: Encoder[Buff] = Encoders.product

    /** Encoder for the output type. */
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}
运行结果: