UDF
一路输入,一路输出
练习需求:模拟获取字符串长度
准备文件
def main(args: Array[String]): Unit = {
//创建Session对象
val spark = SparkSession
.builder() //构建器
.appName("sparkSQL") //序名称程
.master("local[*]") //执行方式:本地
.getOrCreate() //创建对象
//读取数据
val df: DataFrame = spark.read.json("file:///D:\\spark.test\\datas\\people.json")
//方法1:自定义UDF并注册
spark.udf.register("UDFlenth1",(x : String) => x.length)
//方法2:提供一个函数 再注册 register[提供函数返回值类型,函数参数类型]
spark.udf.register[Int,String]("UDFlenth2",lenths)
//定义提供函数实现
def lenths(x : String): Int ={
x.length
}
//建立数据视图表
df.createOrReplaceTempView("people")
spark.sql("select UDFlenth1(name) from people").show()
spark.stop()
}
UDAF
多路输入,一路输出
类似于combineByKey,需要提供一个类继承UserDefinedAggregateFunction,实现抽象方法
练习需求:模拟avg()
object sparkSQL09 {
def main(args: Array[String]): Unit = {
//创建Session对象
val spark = SparkSession
.builder() //构建器
.appName("sparkSQL") //序名称程
.master("local[*]") //执行方式:本地
.getOrCreate() //创建对象
//读取数据
val df: DataFrame = spark.read.json("file:///D:\\spark.test\\datas\\emo.json")
//建立数据视图表
df.createOrReplaceTempView("emp")
//注册UDAF函数
spark.udf.register("MyAvg",new MyUDAF)
//使用
spark.sql("select MyAvg(salary) from emp").show()
spark.stop()
}
}
class MyUDAF extends UserDefinedAggregateFunction{
//输入数据的Schema信息
override def inputSchema: StructType = StructType(
List(StructField("salary",DoubleType,true))
)
//每一个分区中的共享变量 提供分区中聚合之后得到的结果集存储的位置
override def bufferSchema: StructType =
StructType(List(
StructField("sum",DoubleType,true), //工资的总和
StructField("count",DoubleType,true) //分区内工资累加的次数
))
//返回值的数据类型,表示UDAF函数输出结果的输出类型
override def dataType: DataType = DoubleType
//如果有相同输入 是否有相同输出
override def deterministic: Boolean = true //默认为true
//对当前Buffer中属性进行初始化操作,对每个分区进行变量赋值操作
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//工资总和的赋值 sum
buffer(0) = 0.0
//工资累加次数的赋值 count
buffer(1) = 0.0
}
//对分区内数据进行聚合操作
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if(!input.isNullAt(0)){ //判断只要薪水不是空的
//计算一行薪水的工资值
buffer(0) = buffer.getDouble(0) + input.getDouble(0)
//计算次数
buffer(1) = buffer.getDouble(1) + 1
}
}
//全局聚合 ,将分区内计算的数据再聚合在一起
//buffer1 存的是最终全局聚合的数据值 buff2 是对应每个分区计算结果值
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//全局聚合总工资
buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
//全局聚合总次数
buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
}
//最终计算结果
override def evaluate(buffer: Row): Double = {
buffer.getDouble(0) / buffer.getDouble(1)
}
}