-
用户自定义的
UDF
-
定义:
UDF(User-Defined-Function)
,也就是最基本的函数,它提供了SQL
中对字段转换的功能,不涉及聚合操作。例如将日期类型转换成字符串类型,格式化字段。 -
用法
object UDFTest { case class Person(name: String, age: Int) def main(args: Array[String]): Unit = { //常见SparkSession val sparkSession: SparkSession = SparkSession.builder().appName("DataFrameTest").master("local[2]").getOrCreate() //根据文件获取RDD val personRDD: RDD[String] = sparkSession.sparkContext.textFile("C:\\Users\\39402\\Desktop\\person.txt") /** * 注册一个udf函数, * toString:为自定义函数的引用名, * (str: String) => str + "我是UDF自定义函数":这个是自定义的函数体,它是一个匿名函数 */ sparkSession.udf.register("toString", (str: String) => str + "我是UDF自定义函数") import sparkSession.implicits._ //引入隐式转换 //利用反射将RDD转换成DataFrame val personDF: DataFrame = personRDD.map(_.split(",")).map(line => Person(line(0), line(1).toInt)).toDF() //将DataFrame注册成一张表 personDF.createOrReplaceTempView("person") //利用Spark的SQL来查询数据,其中toString就是我们自定义的UDF函数 sparkSession.sql("select toString(name),age from person").show() } }
-
-
用户自定义的
UDAF
-
定义:
UDAF
函数是用户自定义的聚合函数,为Spark SQL
提供对数据集的聚合功能,类似于max()、min()、count()
等功能,只不过自定义的功能是根据具体的业务功能来确定的。因为DataFrame是弱类型的,DataSet是强类型,所以自定义的UDAF
也提供了两种实现,一个是弱类型的一个是强类型的。 -
弱类型用法,需要继承
UserDefindAggregateFunction
,实现它的方法package com.lyz.sql.udf import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ object MyCustomUDAF extends UserDefinedAggregateFunction { //:: Nil 作用就是为StructField常见Array集合,并放入进去 def inputSchema: StructType = StructType(StructField("age", IntegerType) :: Nil) //缓存字段类型,也就是每个分区的共享变量 def bufferSchema: StructType = StructType(StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil) //UDF输出数据类型 def dataType: DataType = IntegerType //输入类型和输出类型是否一致 def deterministic: Boolean = true //初始化分区中的共享变量 def initialize(buffer: MutableAggregationBuffer): Unit = { //初始化每个分区上的年龄总和为0 buffer(0) = 0 //初始化每个分区上的人数为0 buffer(1) = 0 } //每个分区中每一条记录,聚合的时候需要调用该方法 def update(buffer: MutableAggregationBuffer, input: Row): Unit = { //将新输入进来的数据一之前合并的结果做聚合操作, //buffer(0)就是上边定义的年龄总和sum,也就是每个分区上的年龄总和 buffer(0) = buffer.getInt(0) + input.getInt(0) //buffer(1)就是上边定义的人的个数count,也就是每个分区上的人个数 buffer(1) = buffer.getInt(1) + 1 } //对分区结果进行合并 def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { // buffer1(0)就是所有分区的年龄总和 //buffer1.getInt(0) + buffer2.getInt(0):就是将没分区上的年龄相加 //下标为0的就是年龄总和 buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0) //buffer(1)就是所有分区的人个数 //buffer1.getInt(1) + buffer2.getInt(1):就是将每个分区人个数聚合在一起, //下标为1就是人的个数 buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1) } //最终结算结果 def evaluate(buffer: Row): Any = { buffer.getInt(0) / buffer.getInt(1) } }
package com.lyz.sql.udf import com.lyz.sql.dataframe.DataFrameTest.Person import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession} object MyCustomUDAFMain { def main(args: Array[String]): Unit = { val sparkSession: SparkSession = SparkSession.builder().appName("DataFrameTest").master("local[2]").getOrCreate() //根据文件获取RDD val personRDD: RDD[String] = sparkSession.sparkContext.textFile("C:\\Users\\39402\\Desktop\\person.txt") import sparkSession.implicits._ //引入隐式转换 //利用反射将RDD转换成DataFrame val personDF: DataFrame = personRDD.map(_.split(",")).map(line => Person(line(0), line(1).toInt)).toDF() sparkSession.udf.register("myCustomUDAF", MyCustomUDAF) personDF.createOrReplaceTempView("person") /** * 输出结果为:15 */ sparkSession.sql("select myCustomUDAF(age) from person").show() } }
-
强类型用法,需要继承
Aggregate
,实现它的方法。既然是强类型,那么其中肯定涉及到对象的存在package com.lyz.sql.udf import org.apache.spark.sql.{Encoder, Encoders} import org.apache.spark.sql.expressions.Aggregator //输入 case class Person(name: String, age: Int) //缓存变量,也就是逻辑介质, case class Avg(sum: Int, count: Int) object MyCutomUDAFStrong extends Aggregator[Person, Avg, Int] { //初始化缓存变量 def zero: Avg = Avg(0, 0) /** * 每个分区计算各自的结果 * * @param b :聚合后的缓存变量 * @param a :新输入的数据 * @return b:聚合后的缓存变量 */ def reduce(b: Avg, a: Person): Avg = { b.sum += a.age b.count += 1 b } //合并每个分区的结果 def merge(b1: Avg, b2: Avg): Avg = { b1.sum += b2.sum b1.count += b2.count b1 } //最后完成平均值的计算 def finish(reduction: Avg): Int = { reduction.sum / reduction.count } //Encoders.product:是对scala元组和case类型转换的编码器 def bufferEncoder: Encoder[Avg] = Encoders.product //设定输出值的编码器 def outputEncoder: Encoder[Int] = Encoders.scalaInt }
package com.lyz.sql.udf import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Dataset, SparkSession, TypedColumn} object MyCustomStrongMain { def main(args: Array[String]): Unit = { val sparkSession: SparkSession = SparkSession.builder().appName("DataFrameTest").master("local[2]").getOrCreate() //根据文件获取RDD val personRDD: RDD[String] = sparkSession.sparkContext.textFile("C:\\Users\\39402\\Desktop\\person.txt") import sparkSession.implicits._ //引入隐式转换 //里用RDD生成Dataset val personDS: Dataset[Person] = personRDD.map(_.split(",")).map(line => Person(line(0), line(1).toInt)).toDS() //将这个函数转成TypedColumn,并且提供一个别名 val avgAge: TypedColumn[Person, Int] = MyCustomUDAFStrong.toColumn.name("ageAvg") personDS.select(avgAge).show() } }
-
Spark SQL(二十二)用户自定义的UDF、UDAF函数
最新推荐文章于 2021-10-27 22:08:31 发布