【Spark原理系列】自定义聚合函数 UserDefinedAggregateFunction 原理用法示例源码分析

Spark 自定义聚合函数(UDAF)UserDefinedAggregateFunction 原理用法示例源码分析

原理

UserDefinedAggregateFunction 是 Spark SQL 中用于实现用户自定义聚合函数(UDAF)的抽象类。通过继承该类并实现其中的方法,可以创建自定义的聚合函数,并在 Spark SQL 中使用。

UserDefinedAggregateFunction原理是基于 Spark SQL 的聚合操作流程。当一个 UDAF 被应用到 DataFrame 上时,Spark SQL 会将 UDAF 转化为一个 AggregateExpression 对象,其中包含了对应的 ScalaUDAF 实例和聚合操作类型。然后,Spark SQL 会对数据进行分组、聚合等操作,并调用 UDAF 中的方法来执行具体的聚合逻辑。

在具体实现中,UserDefinedAggregateFunction 提供了一系列方法,如 inputSchemabufferSchemadataType 等,用于定义输入参数的数据类型、缓冲区中值的数据类型以及返回值的数据类型。同时,它还提供了 initializeupdatemergeevaluate 方法,用于初始化聚合缓冲区、更新缓冲区、合并缓冲区以及计算最终结果。此外,UserDefinedAggregateFunction 还提供了 applydistinct 方法,用于创建 Column 对象,方便在 DataFrame 中使用自定义聚合函数。

总的来说,UserDefinedAggregateFunction 通过定义一系列方法,使得用户可以灵活地实现自定义的聚合逻辑,并将其应用到 Spark SQL 的聚合操作中。通过这种方式,用户可以扩展 Spark SQL 中的聚合能力,满足特定的业务需求。

用法

方法名描述
inputSchema返回聚合函数的输入参数的数据类型StructType
bufferSchema返回聚合缓冲区中值的数据类型的 StructType
dataType返回聚合函数的返回值数据类型
deterministic返回布尔值,指示此函数是否是确定性的。
initialize(buffer)初始化给定的聚合缓冲区。
update(buffer, input)使用新的输入数据更新聚合缓冲区。
merge(buffer1, buffer2)合并两个聚合缓冲区。
evaluate(buffer)根据给定的聚合缓冲区计算最终结果
apply(exprs)使用给定的 Column 参数创建一个 Column 对象来调用 UDAF。
distinct(exprs)使用给定的不同值的 Column 参数创建一个 Column 对象来调用 UDAF。
update(i, value)更新可变聚合缓冲区的第 i 个值。

示例

package org.example.spark
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
object AverageVecDemo {
  // 创建自定义聚合函数
  class MyAverage extends UserDefinedAggregateFunction {
    // 输入参数的数据类型
    def inputSchema: StructType = new StructType().add("value", DoubleType)

    // 聚合缓冲区中值的数据类型
    def bufferSchema: StructType = new StructType()
      .add("sum", DoubleType)
      .add("count", LongType)

    // 返回值的数据类型
    def dataType: DataType = DoubleType

    // 是否是确定性的
    def deterministic: Boolean = true

    // 初始化聚合缓冲区
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0.0 // sum
      buffer(1) = 0L  // count
    }

    // 更新聚合缓冲区
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        val value = input.getDouble(0)
        buffer(0) = buffer.getDouble(0) + value
        buffer(1) = buffer.getLong(1) + 1
      }
    }

    // 合并两个聚合缓冲区
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    // 计算最终结果
    def evaluate(buffer: Row): Any = {
      buffer.getDouble(0) / buffer.getLong(1)
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._

    // 创建一个 DataFrame
    val data = Seq(1.0, 2.0, 3.0, 4.0, 5.0).toDF("value")

    // 注册自定义聚合函数
    spark.udf.register("myAverage", new MyAverage)

    // 使用自定义聚合函数进行聚合操作
    val result = data.selectExpr("myAverage(value) as average")

    result.show()

    spark.stop()
  }
}

//+-------+
//|average|
//+-------+
//|    3.0|
//+-------+

这个示例中,我们创建了一个自定义聚合函数 MyAverage,用于计算输入数据列的平均值。然后,我们将该函数注册到 Spark 的 UDF(用户定义函数)中,并在 DataFrame 中使用 selectExpr 方法调用它进行聚合操作。最后,我们展示了聚合结果。

源码

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.types._

/**
 * 实现用户自定义聚合函数(UDAF)的基类。
 *
 * @since 1.5.0
 */
@InterfaceStability.Stable
abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * `StructType` 表示此聚合函数的输入参数的数据类型。
   * 例如,如果一个[[UserDefinedAggregateFunction]]期望两个输入参数,
   * 分别是`DoubleType`和`LongType`类型,返回的`StructType`将如下所示:
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * 此`StructType`的字段名称仅用于标识对应的输入参数。用户可以选择名称以标识输入参数。
   *
   * @since 1.5.0
   */
  def inputSchema: StructType

  /**
   * `StructType` 表示聚合缓冲区中值的数据类型。
   * 例如,如果一个[[UserDefinedAggregateFunction]]的缓冲区有两个值
   * (即两个中间值),分别是`DoubleType`和`LongType`类型,
   * 返回的`StructType`将如下所示:
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * 此`StructType`的字段名称仅用于标识对应的缓冲区值。用户可以选择名称以标识输入参数。
   *
   * @since 1.5.0
   */
  def bufferSchema: StructType

  /**
   * [[UserDefinedAggregateFunction]] 返回值的 `DataType`。
   *
   * @since 1.5.0
   */
  def dataType: DataType

  /**
   * 如果此函数是确定性的,则返回true,即给定相同的输入,总是返回相同的输出。
   *
   * @since 1.5.0
   */
  def deterministic: Boolean

  /**
   * 初始化给定的聚合缓冲区,即聚合缓冲区的初始值。
   *
   * 即应用于两个初始缓冲区的合并函数只应返回初始缓冲区本身,即
   * `merge(initialBuffer, initialBuffer)` 应等于 `initialBuffer`。
   *
   * @since 1.5.0
   */
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * 使用来自`input`的新输入数据更新给定的聚合缓冲区`buffer`。
   *
   * 每行输入调用一次此方法。
   *
   * @since 1.5.0
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * 合并两个聚合缓冲区,并将更新后的缓冲区值存储回`buffer1`。
   *
   * 当我们合并两个部分聚合的数据时,会调用此方法。
   *
   * @since 1.5.0
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * 根据给定的聚合缓冲区计算此[[UserDefinedAggregateFunction]]的最终结果。
   *
   * @since 1.5.0
   */
  def evaluate(buffer: Row): Any

  /**
   * 使用给定的`Column`s作为输入参数创建此UDAF的`Column`。
   *
   * @since 1.5.0
   */
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }

  /**
   * 使用给定的`Column`s的不同值作为输入参数创建此UDAF的`Column`。
   *
   * @since 1.5.0
   */
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}

/**
 * 表示可变聚合缓冲区的`Row`。
 *
 * 不建议在Spark之外扩展它。
 *
 * @since 1.5.0
 */
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {

  /** 更新此缓冲区的第i个值。 */
  def update(i: Int, value: Any): Unit
}
gregateExpression)
  }
}

/**
 * 表示可变聚合缓冲区的`Row`。
 *
 * 不建议在Spark之外扩展它。
 *
 * @since 1.5.0
 */
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {

  /** 更新此缓冲区的第i个值。 */
  def update(i: Int, value: Any): Unit
}

参考链接

https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html

  • 19
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: Spark SQL可以通过自定义聚合函数来实现更加灵活的数据处理。自定义聚合函数可以根据具体的业务需求,对数据进行自定义的聚合操作,例如计算平均值、求和、最大值、最小值等。 要实现自定义聚合函数,需要继承Aggregator类,并实现其抽象方法。Aggregator类包含三个泛型参数,分别为输入数据类型、缓冲区数据类型和输出数据类型。在实现Aggregator类时,需要重写其三个方法:zero、reduce和merge。 其中,zero方法用于初始化缓冲区,reduce方法用于对输入数据进行聚合操作,merge方法用于合并不同分区的缓冲区数据。最后,还需要实现finish方法,用于将缓冲区中的数据转换为输出数据。 完成自定义聚合函数的实现后,可以通过Spark SQL的API将其注册为UDAF(User-Defined Aggregate Function),并在SQL语句中使用。 例如,假设需要计算某个表中某个字段的平均值,可以先定义一个自定义聚合函数: ``` import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.Encoder case class AvgBuffer(var sum: Double = 0.0, var count: Int = 0) class Avg extends Aggregator[Double, AvgBuffer, Double] { def zero: AvgBuffer = AvgBuffer() def reduce(buffer: AvgBuffer, data: Double): AvgBuffer = { buffer.sum += data buffer.count += 1 buffer } def merge(buffer1: AvgBuffer, buffer2: AvgBuffer): AvgBuffer = { buffer1.sum += buffer2.sum buffer1.count += buffer2.count buffer1 } def finish(buffer: AvgBuffer): Double = buffer.sum.toDouble / buffer.count def bufferEncoder: Encoder[AvgBuffer] = Encoders.product def outputEncoder: Encoder[Double] = Encoders.scalaDouble } ``` 然后,将其注册为UDAF: ``` val avg = new Avg spark.udf.register("myAvg", avg) ``` 最后,在SQL语句中使用该自定义聚合函数: ``` SELECT myAvg(salary) FROM employee ``` ### 回答2: Spark SQL是一款开源的分布式计算框架,它支持使用SQL语言进行数据查询和分析,同时可以与Hadoop、Hive等大数据技术进行无缝集成。Spark SQL中的自定义聚合函数,是指用户自己定义一些聚合函数,然后将它们应用到Spark SQL的查询中,从而实现更加灵活和高效的数据分析功能。 在Spark SQL中实现自定义聚合函数,需要遵循以下几个步骤: 1.创建自定义聚合函数类 首先需要创建一个类,该类继承自Aggregator,并实现其中定义的抽象方法。这些方法包括两个泛型:输入类型和累加器类型。输入类型为需要进行聚合的数据类型,累加器类型为处理一个分区的聚合结果类型。 例如,如果我们需要自定义一个计算平均值的聚合函数,那么可以创建一个类如下: class Average extends Aggregator[Double, (Double, Int), Double] { //初始化累加器方法 def zero: (Double, Int) = (0.0, 0) //聚合方法,输入数据类型为Double def reduce(acc: (Double, Int), x: Double): (Double, Int) = (acc._1 + x, acc._2 + 1) //合并累加器方法 def merge(acc1: (Double, Int), acc2: (Double, Int)):(Double, Int) = (acc1._1 + acc2._1, acc1._2 + acc2._2) //输出结果类型为Double类型 def finish(acc: (Double, Int)): Double = acc._1 / acc._2 } 在这个例子中,我们定义了一个计算平均值的聚合函数,其中输入数据类型为Double,累加器类型为一个元组(Double, Int),表示聚合结果的累加器分别包含总和和个数,输出结果类型为Double。 2.注册聚合函数 在创建完自定义聚合函数类后,需要使用SparkSession的udf方法来将它注册为一个UDAF(用户自定义聚合函数)。参看以下代码: val average = new Average().toColumn.name("average") spark.udf.register("average", average) 这里,我们将Average类实例化,然后使用toColumn方法将其转换为一个Column,使用name方法为该列命名为"average"。最后,使用SparkSession的udf方法将该列注册为一个UDAF,命名为"average"。 3.应用聚合函数聚合函数注册完毕后,就可以在查询中使用聚合函数进行数据分析了。参看以下代码: val data = Seq((1, 2.0), (1, 2.0), (2, 3.0), (2, 4.0), (2, 3.0)).toDF("group", "value") data.groupBy("group").agg(expr("average(value)") as "avg").show() //输出如下: //+-----+----+ //|group| avg| //+-----+----+ //| 1| 2.0| //| 2| 3.3| //+-----+----+ 在这个例子中,我们使用了数据帧来模拟一组数据,其中包含group和value两个字段。以下查询语句将数据按照group字段进行分组,并使用预先定义的聚合函数"average"计算每组的平均数。最后,使用show()方法展示查询结果。 总而言之,通过自定义聚合函数,可以为Spark SQL增加更多的聚合功能,从而使数据分析处理更加灵活和高效。 ### 回答3: Spark SQL是一个基于Spark的SQL查询工具,可以将结构化和半结构化数据导入到数据仓库中。在Spark SQL中实现自定义聚合函数非常重要,因为聚合函数是大型数据分析中最重要的部分之一。下面,我们将讨论如何在Spark SQL中实现自定义聚合函数Spark SQL中的聚合函数Spark SQL中,聚合函数是SQL查询语句中用于计算一个数据集中值的函数。这些函数包括最小值,最大值,求和,平均值和计数函数等。 由于Spark SQL是用Scala编写的,因此我们可以在其上下文中定义和使用Scala函数。但是,为了使函数能够在SQL查询中使用,我们需要将它们转换为聚合函数。 定义聚合函数 要定义聚合函数,我们需要定义一个包含聚合函数的类并扩展Aggregator trait。该类必须定义三个类型:输入类型,中间类型和输出类型。 输入类型指的是需要在聚合函数中使用的数据类型。在本例中,我们将使用一个整数类型的输入数据。 中间类型指的是在计算过程中使用的数据类型。这个类型可以是任何类型,只要它们可以相加,并在最后输出结果。在本例中,我们将中间类型定义为一个二元组类型。 输出类型指最终聚合函数的结果类型。因此,我们将输出类型定义为一个double类型的数据。 现在,我们可以定义一个具有以上规则的自定义聚合函数: import org.apache.spark.sql.expressions._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ object MyAggregator extends Aggregator[Int, (Int, Int), Double] { override def zero: (Int, Int) = (0, 0) override def reduce(b: (Int, Int), a: Int): (Int, Int) = (b._1 + a, b._2 + 1) override def merge(b1: (Int, Int), b2: (Int, Int)): (Int, Int) = (b1._1 + b2._1, b1._2 + b2._2) override def finish(r: (Int, Int)): Double = r._1.toDouble / r._2 override def bufferEncoder: Encoder[(Int, Int)] = Encoders.product[(Int, Int)] override def outputEncoder: Encoder[Double] = Encoders.scalaDouble } 解释: zero方法返回一个中间类型的初始值。在这个例子中,我们使用(0, 0)作为初始值。 reduce 方法使用输入类型的值和中间类型的值并返回一个新的中间类型的值。 merge方法将两个中间类型的值合并成一个中间类型的值。 finish方法将最终的中间类型的值转换为输出类型的值。 bufferEncoder和outputEncoder方法分别定义缓冲区类型和输出类型的编码器。 使用自定义函数 一旦自定义聚合函数定义完成,我们可以在SQL查询中使用它。假设我们有以下数据集: +---+ |num| +---+ | 1| | 2| | 3| | 4| | 5| +---+ 我们可以使用以下查询来使用我们的自定义聚合函数并计算平均数: val df = Seq(1, 2, 3, 4, 5).toDF("num") df.agg(MyAggregator.toColumn.name("avg")).show() 输出: +---+ |avg| +---+ |3.0| +---+ 总结 Spark SQL中自定义聚合函数的过程稍微有些困难,但是一旦我们定义了自定义聚合函数,我们就可以将其用作SQL查询中的任何其他聚合函数。而且在使用它时,我们可以拥有无限的灵活性来定义任何形式的自定义聚合函数
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值