一、UDF
package com.zgm.sc.day14
import org.apache.spark.sql.SparkSession
/**
* 用udf实现字符串拼接
*/
object UDFDemo1 {
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("UDFDemo1")
.master("local")
.getOrCreate()
val df = spark.read.json("dir/people.json")
// 注册函数,注册后正在整个应用中都可以用
spark.udf.register("newname", (x: String) => "name:" + x)
df.createOrReplaceTempView("person")
spark.sql("select newname(name) as new_name from person").show()
spark.stop()
}
}
//运行结果:
+------------+
| new_name|
+------------+
|name:Michael|
| name:Andy|
| name:Justin|
+------------+
二、UDAF
用户自定义聚合函数
1、UDAF函数支持DataFrame(弱类型)
过继承UserDefinedAggregateFunction来实现用户自定义聚合函数。下面展示一个求平均工资的自定义聚合函数。
ps:弱类型指的是在编译阶段是无法确定数据类型的,而是在运行阶段才能创建类型
package com.qf.gp1921.day13
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}
/**
* 使用UDAF操作DataFrame
* 需求:用UDAF统计员工的平均薪资
*/
class UDAFDemo1 extends UserDefinedAggregateFunction {
// 指定输入类型
override def inputSchema: StructType = StructType(Array(StructField("salary", DoubleType, true)))
// 缓冲的作用是将上次的结果和这次传进来的结果进行聚合,指定缓冲(分区)类型和聚合过程
override def bufferSchema: StructType =
StructType(StructField("sum", DoubleType) :: StructField("count", DoubleType) :: Nil)
// 返回类型
override def dataType: DataType = DoubleType
// 如果给true,有相同的输入,该函数就有相同的输出
// 如果输入的数据有不同的情况,比如每次数据有不同的时间或有不同的数据对应的offset,
// 这时候得到的结果可能就不一样,这个值就设置为false
override def deterministic: Boolean = true
// 初始化方法,对buffer中的数据进行初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
// 员工的总薪资
buffer(0) = 0.0
// 员工的人数
buffer(1) = 0.0
}
// 局部聚合,分区内的聚合
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// buffer是指之前的结果,input是这次传进来的数据,将要和buffer进行聚合
if (!input.isNullAt(0)) {
// 聚合薪资
buffer(0) = buffer.getDouble(0) + input.getDouble(0)
// 聚合人数
buffer(1) = buffer.getDouble(1) + 1
}
}
// 全局聚合,分区和分区的聚合
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
// 合并薪资
buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
// 合并人数
buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
}
// 最终的结果,可以在该方法中对结果进行再次处理
override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getDouble(1)
}
object UDAFDemo1 {
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("UDFDemo1")
.master("local")
.getOrCreate()
spark.udf.register("aggrfunc", new UDAFDemo1)
val df = spark.read.json("dir/employees.json")
df.createOrReplaceTempView("employees")
df.show()
val res = spark.sql("select aggrfunc(salary) as avgsalary from employees")
res.show()
spark.stop()
}
}
//运行结果
+------+
|avgsalary|
+------+
|3750.0|
+------+
2.UDAF函数支持DataSet(强类型)
通过继承Aggregator来实现强类型自定义聚合函数,同样是求平均工资
ps:在编译阶段就确定了数据类型
package com.qf.gp1921.day13
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession, TypedColumn}
/**
* udaf操作DataSet
*/
// 生成DataSet的类型
case class Employee(name: String, salary: Double)
// 作为缓冲的类型
case class AvgSalary(var sum: Double, var count: Double)
class UDAFDemo2 extends Aggregator[Employee, AvgSalary, Double]{
// 初始化方法,初始化每个buffer
override def zero: AvgSalary = AvgSalary(0.0, 0.0)
// 局部聚合
override def reduce(buffer: AvgSalary, employee: Employee): AvgSalary = {
buffer.sum += employee.salary // 聚合薪资
buffer.count += 1 // 聚合人数
buffer
}
// 全局聚合
override def merge(b1: AvgSalary, b2: AvgSalary): AvgSalary = {
b1.sum += b2.sum // 分区和分区的聚合,聚合薪资
b1.count += b2.count // 聚合人数
b1
}
// 计算结果
override def finish(reduction: AvgSalary): Double = reduction.sum / reduction.count
// 设置中间值的编码, 用的编码和Tuple和case是一样的
override def bufferEncoder: Encoder[AvgSalary] = Encoders.product
// 设置最终结果的编码
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
object UDAFDemo2 {
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("UDFDemo2")
.master("local")
.getOrCreate()
import spark.implicits._
val ds: Dataset[Employee] = spark.read.json("dir/employees.json").as[Employee]
ds.show
// 指定某个列并调用udaf
val avgsalary: TypedColumn[Employee, Double] = new UDAFDemo2().toColumn.name("avg_salary")
val res: Dataset[Double] = ds.select(avgsalary)
res.show
spark.stop()
}
}
//运行结果
+------+
|avg_salary|
+------+
|3750.0|
+------+