聚合
DataFrames可以提供共同聚合,例如count(),countDistinct(),avg(),max(),min()等。虽然这些功能是专为DataFrames,星火SQL还拥有类型安全的版本,在其中的一些 斯卡拉和 Java的使用强类型数据集的工作。此外,用户不限于预定义的聚合函数,并且可以创建自己的聚合函数。
无用户定义的聚合函数
扩展UserDefinedAggregateFunction 抽象类以实现自定义无类型聚合函数。
例如,用户定义的平均值:
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
object MyAverage extends UserDefinedAggregateFunction {
// 集合函数的输入类型
def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
// 集合缓冲区的数据类型
def bufferSchema: StructType = {
StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
}
//返回值的数据类型
def dataType : DataType = DoubleType
//确定是否返回相同的输出
def deterministic: Boolean = true
//初始化给定的聚合缓冲区。缓冲区本身是一个`Row`,除了标准方法之外,比如在索引处检索值(例如,get(),getBoolean()),提供了更新其值的机会。请注意,缓冲区内的数组和映射仍然是不可变的。
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
buffer(1) = 0L
}
// 更新数据到指定的聚合缓冲区`buffer
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0)) {
buffer(0) = buffer.getLong(0) + input.getLong(0)
buffer(1) = buffer.getLong(1) + 1
}
}
// 合并两个聚合缓冲剂和存储更新的缓冲器值回`buffer1`
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// 计算最终结果
def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}
// 注册访问的函数
spark.udf.register("myAverage", MyAverage)
val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
用户定义聚合函数
强类型数据集的用户定义聚合围绕Aggregator抽象类。
例如,类型安全的用户定义平均值所示:
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)
object MyAverage extends Aggregator[Employee, Average, Double] {
// 此聚合的零值。应该满足任何b + zero = b
def zero: Average = Average(0L, 0L)
// 组合两个值以产生新值。为了提高性能,该函数可以修改`buffer`
//并返回它而不是构造一个新的对象
def reduce(buffer: Average, employee: Employee): Average = {
buffer.sum += employee.salary
buffer.count += 1
buffer
}
// 合并两个中间值
def merge(b1: Average, b2: Average): Average = {
b1.sum += b2.sum
b1.count += b2.count
b1
}
// 转换减少
def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
// 指定中间件类型
def bufferEncoder: Encoder[Average] = Encoders.product
// 为最终输出值类型指定编码器
def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
ds.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
// 将函数转换为`TypedColumn`并给它命名
val averageSalary = MyAverage.toColumn.name("average_salary")
val result = ds.select(averageSalary)
result.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+