UDF:即用户自定义函数
UDAF:即用户自定义聚合函数
#UDF案例一:
import org.apache.spark.sql.{DataFrame, SparkSession}
//用户自定义函数
object _UdfDemo1 {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession
.builder()
.master("local[*]")
.appName(" ")
.getOrCreate()
val df: DataFrame = spark.read.json("sql/emp.json")
df.createTempView("emp")
//1.定义一个方法
def fun1(word:String)={
word.length
}
//2.注册函数 fun1 _ :将方法转成函数
// spark.udf.register("mylength",fun1 _)
//不用定义函数,直接使用匿名函数的写法
spark.udf.register("myudf1",{word:String=>word.length})
//3.SQ风格的写法
spark.sql("select ename,myudf1(ename) as lg from emp where myudf1(ename)>4").show()
spark.stop()
}
}
#UDF案例二:
import org.apache.spark.sql.{DataFrame, SparkSession}
//用户自定义函数
object _UdfDemo2 {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession
.builder()
.master("local[*]")
.appName(" ")
.getOrCreate()
val df: DataFrame = spark.read.json("sql/emp.json")
df.createTempView("emp")
//需求分析:显示每个员工的工资等级,
// sal>3000 显示level3
// sal>1500 显示level2
// 其他的显示level1
//1.自定义一个fun2函数
def fun2(num:Double)={
if (num>3000)
"level3"
else if (num>1500)
"level3"
else
"level1"
}
//2.注册自定义函数
spark.udf.register("myudf2",fun2 _)
//3.写SQL风格语句,使用自定义函数
val sql =
"""
|select
|count(1),
|myudf2(sal)
|from
|emp
|group by
|myudf2(sal)
|""".stripMargin
//4.传入sql,显示数据
spark.sql(sql).show()
spark.stop()
}
}
#UDAF案例一:
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
//用户自定义聚合函数
object _UdafDemo3 {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession
.builder()
.master("local[*]")
.appName(" ")
.getOrCreate()
val df: DataFrame = spark.read.json("sql/emp.json")
df.createTempView("emp")
/**
* 名词解释:
* udf:一对一关系 传入一行输出一行
* dfaf:多对一关系 传入多行输出一行
*/
//需求:
//查询每个部门的平均工资
//2.注册函数
spark.udf.register("myudaf1",new MyUDAF)
//3.编写sql语句
val sql =
"""
|select
|deptno,
|myudaf1(sal)
|from
|emp
|group by deptno
|""".stripMargin
//4,显示结果
spark.sql(sql).show()
spark.stop()
}
}
/**
* 1.
* 需要自定一个类型,用来继承UDAF的相关类型
* 需要根据需求重写继承类的方法
*/
class MyUDAF extends UserDefinedAggregateFunction{
//--描述进入函数的参数的类型
override def inputSchema: StructType = StructType(
Array(
StructField("num",DoubleType)
)
)
//--描述计算过程中涉及到的变量的类型
override def bufferSchema: StructType = StructType(
Array(
StructField("sum",DoubleType),
StructField("count",LongType)
)
)
//--描述计算结果的类型
override def dataType: DataType = DoubleType
//--描述函数的稳定性
override def deterministic: Boolean = true
//--对计算过程中涉及到的两个变量进行初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//buffer的第一个元素表示sum
buffer(0) = 0D
//buffer的第二个元素表示count
buffer(1) = 0L
}
/**
* 预聚合操作
* @param buffer 参数buffer指的使当前的缓存中的数据,相当this
* @param input 表示刚刚进入函数内的这一条记录
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//update 用于更新数据
buffer.update(0,buffer.getDouble(0)+input.getDouble(0))
//+1操作
buffer.update(1,buffer.getLong(1)+1)
}
/**
* 分区间的合并操作
* @param buffer1 要返回的数据
* @param buffer2 另外一个分区的数据
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0,buffer1.getDouble(0)+buffer2.getDouble(0))
buffer1.update(1,buffer1.getLong(1)+buffer2.getLong(1))
}
//用于计算结果
override def evaluate(buffer: Row): Any ={
buffer.getDouble(0)/buffer.getLong(1)
}
}