UDF
- UDF(User Define Function)
spark内置的函数不能满足我们的要求的时候,我们通常需要自定义函数来实现我们的需求。 - 示例
object UDF {
def main(args: Array[String]): Unit = {
val sparkSession = SparkSession.builder()
.appName("UDF")
.master("local[2]")
.getOrCreate()
// 创建一个RDD
val names = Array("zhangsan", "lisi", "wangwu", "Tom", "Jerry", "Alan")
val namesRDD = sparkSession.sparkContext.parallelize(names)
// 转为DF
// 方法一:动态加载
val namesRowRDD = namesRDD.map(x => Row(x))
val schema = StructType(Array(
StructField("name", StringType, true)
))
val namesDF = sparkSession.createDataFrame(namesRowRDD, schema)
// 方法二:反射
import sparkSession.implicits._
val namesDF = namesRDD.toDF("name")
namesDF.createOrReplaceTempView("udfTest")
// 注册UDF函数
sparkSession.udf.register("strLength", (str:String) =>str.length)
sparkSession.sql("select name, strLength(name) length from udfTest").show()
}
}
+--------+------+
| name|length|
+--------+------+
|zhangsan| 8|
| lisi| 4|
| wangwu| 6|
| Tom| 3|
| Jerry| 5|
| Alan| 4|
+--------+------+
UDAF
- UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是:普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值按指定方法聚合一下。
- UDAF的使用有两种方式:1.继承UserDefinedAggregateFunction 2.继承Aggregator(优点是可以带类型)
使用UDAF一般步骤:
- 自定义类继承UserDefinedAggregateFunction或者Aggregator,对每个阶段方法做实现
- 在sparkSession中注册UDAF,为其绑定一个名字
- 在sql语句中使用上面绑定的名字调用
- 继承UserDefinedAggregateFunction
class UDAFStringCount extends UserDefinedAggregateFunction {
// 输入数据的类型
override def inputSchema: StructType = {
StructType(Array(StructField("str", StringType, true)))
}
// 中间聚合时所处理的数据
override def bufferSchema: StructType = {
StructType(Array(StructField("count", IntegerType, true)))
}
// 函数返回的类型
override def dataType: DataType = {
IntegerType
}
// 指定是否是确定性的
override def deterministic: Boolean = {
true
}
// 为每个分组的数据执行初始化操作
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0
}
// 每个分组有新值过来,如何进行分组对应的聚合值的计算
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getAs[Int](0) + 1
}
// 合并,一个分组的数据会分布在多个节点上处理,所以最后要用merge
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
}
// 通过中间的缓存聚合值,最后返回一个最终的聚合值
override def evaluate(buffer: Row): Any = {
buffer.getAs[Int](0)
}
}
- 注册udaf,sql调用
object UDAF {
def main(args: Array[String]): Unit = {
//UDAF可以针对多行输入,进行聚合计算,返回一个输出
val sparkSession = SparkSession.builder()
.appName("UDAF")
.master("local[2]")
.getOrCreate()
val names = Array("zhangsan", "lisi", "wangwu", "Tom", "Jerry", "zhangsan", "Tom", "zhangsan",
"lisi", "wangwu", "Tom", "Jerry", "Alan")
val namesRDD = sparkSession.sparkContext.parallelize(names)
import sparkSession.implicits._
val namesDF = namesRDD.toDF("name")
namesDF.createOrReplaceTempView("udafTest")
sparkSession.udf.register("strCount", new UDAFStringCount)
sparkSession.sql("select name, strCount(name) len from udafTest group by name").show()
}
}
+--------+---+
| name|len|
+--------+---+
| wangwu| 2|
| Tom| 3|
| Jerry| 2|
|zhangsan| 3|
| Alan| 1|
| lisi| 2|
+--------+---+
- 继承Aggregator
在这里插入代码片