SparkSQL中自定义聚合(UDAF)函数

用户自定义函数类别分为以下三种:

1).UDF:输入一行,返回一个结果(一对一),在上篇案例 使用SparkSQL实现根据ip地址计算归属地二 中实现的自定义函数就是UDF,输入一个十进制的ip地址,返回一个省份

2).UDTF:输入一行,返回多行(一对多),在SparkSQL中没有,因为Spark中使用flatMap即可实现这个功能

3).UDAF:输入多行,返回一行,这里的A是aggregate,聚合的意思,如果业务复杂,需要自己实现聚合函数

下面就来介绍如何自定义UDAF聚合函数

以一个实际案例来介绍,这个案例是求几何平均数的,几何平均数不知道的可以去百度百科看,简单来说就是求n个数乘积的开n次方,但是如果这里的n很大,在单机上根本就运算不了怎么办,我们可以在Spark集群上执行这个任务

思路:

如图所示: 在集群的机器的分区内执行计算出各自的n和t,然后汇总到一起再执行Math.pow来计算几何平均数

具体代码实现:

package cn.ysjh0014.SparkSql
 
import java.lang
 
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession}
 
object UDAFTest {
 
  def main(args: Array[String]): Unit = {
 
    val session: SparkSession = SparkSession.builder().appName("UDAFTest").master("local[*]").getOrCreate()
 
 
    val udaf = new UDAFys
 
    //注册函数
//    session.udf.register("udaf",udaf)
     val range: Dataset[lang.Long] = session.range(1, 11)
//    range.createTempView("table")
//    val df = session.sql("SELECT udaf(id) result FROM table")
 
    import session.implicits._
 
    val df = range.agg(udaf($"id").as("geomean"))
    df.show()
 
    session.stop()
  }
}
 
 
class UDAFys extends UserDefinedAggregateFunction {
 
  //输入数据的类型
  override def inputSchema: StructType = StructType(List(
    StructField("value", DoubleType)
  ))
 
  //产生中间结果的数据类型
  override def bufferSchema: StructType = StructType(List(
    //相乘之后返回的积
    StructField("project", DoubleType),
    //参与运算数字的个数
    StructField("Num", LongType)
  ))
 
  //最终返回的结果类型
  override def dataType: DataType = DoubleType
 
  //确保一致性,一般用true
  override def deterministic: Boolean = true
 
  //指定初始值
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //相乘的初始值,这里的要和上边的中间结果的类型和位置相对应
    buffer(0) = 1.0
    //参与运算数字个数的初始值
    buffer(1) = 0L
  }
 
  //每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的计算)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    //每有一个数字参与运算就进行相乘(包含中间结果)
    buffer(0) = buffer.getDouble(0) * input.getDouble(0)
    //参与运算的数字个数更新
    buffer(1) = buffer.getLong(1) + 1L
  }
 
  //全局聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //每个分区计算的结果进行相乘
    buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
 
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
 
  //计算最终的结果
  override def evaluate(buffer: Row): Any = {
    math.pow(buffer.getDouble(0), 1.toDouble / buffer.getLong(1))
  }
}

下面是复杂筛选下的demo

package test.udaf

import org.apache.log4j.Logger
import org.apache.spark.Partition
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{DataFrame, SparkSession}

object UDAFTest2 {

  val logger = Logger.getLogger(this.getClass)

  val spark: SparkSession = SparkSession
    .builder()
    .appName("local-test")
    .master("local[4]")
    //.enableHiveSupport()
    //.config("spark.shuffle.service.enabled", true)
    //.config("spark.driver.maxResultSize", "4G")
    //.config("spark.sql.parquet.writeLegacyFormat", true)
    .getOrCreate()

  spark.sparkContext.setLogLevel("warn")
  import spark.implicits._


  def main(args: Array[String]): Unit = {


    val dataSource = Seq(
      ("11111110", "F2", "100"),
      ("11111111", "F2", "200"),
      ("11111112", "F2", "300"),
      ("11111113", "F2", "400"),
      ("11111114", "F2", "500"),
      ("11111115", "F2", "600"),
      ("11111116", "F2", "700"),
      ("11111117", "F2", "800"),
      ("11111118", "F2", "900"),
      ("11111119", "F2", "1000"),
      ("22222220", "F3", "100"),
      ("22222221", "F3", "200"),
      ("22222222", "F3", "300"),
      ("22222223", "F3", "400"),
      ("22222224", "F3", "500"),
      ("22222225", "F3", "600"),
      ("22222226", "F3", "700"),
      ("22222227", "F3", "800"),
      ("22222228", "F3", "900"),
      ("22222229", "F3", "1000")
    )
    val rawDF: DataFrame = spark.createDataFrame(dataSource).toDF("user", "platform", "fe")
      .withColumn("fe", $"fe".cast(DoubleType))

    val copyDF = rawDF
      .select(
        (rawDF.columns).map(i => col(i).alias(s"${i}2")): _*
      )

    val joinDF = rawDF.join(copyDF, rawDF("platform") === copyDF("platform2"), "inner")
      .filter($"user" =!= $"user2")

    val partitions: Int = joinDF.rdd.getNumPartitions
    val partitions1: Array[Partition] = joinDF.rdd.partitions

    logger.warn("========>>>>>>>> " + partitions)
    partitions1.foreach(e => {
      println(e.index)
    })

    joinDF.printSchema()
    joinDF.show(1000,false)

    val cal_fe_sigma_udf: UserDefinedAggregateFunction = spark.udf.register("cal_fe_sigma", new SigmaUdafTest())

    val tmpDF = joinDF.groupBy("user", "platform")
      .agg(
        cal_fe_sigma_udf($"user", $"fe", $"user2", $"fe2")
      )

    tmpDF.printSchema()
    tmpDF.show(false)

  }


}

 

package test.udaf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

class SigmaUdafTest extends UserDefinedAggregateFunction {

  // 聚合函数的输入数据结构
  //  override def inputSchema: StructType = {
  //    StructType(StructField("esn", StringType)
  //      :: StructField("fe", DoubleType)
  //      :: StructField("fe2", DoubleType)
  //      :: Nil)
  //  }
  override def inputSchema: StructType = {
    new StructType()
      .add("user", StringType)
      .add("fe", DoubleType)
      .add("user2", StringType)
      .add("fe2", DoubleType)
  }

  //缓存数据类型 即在聚合计算过程当中的中间结果数据类型
  override def bufferSchema: StructType = {
    new StructType()
      .add("userArray", DataTypes.createArrayType(StringType))
      .add("feArray", DataTypes.createArrayType(DoubleType))
      .add("user2Array", DataTypes.createArrayType(StringType))
      .add("fe2Array", DataTypes.createArrayType(DoubleType))
  }

  // 聚合函数返回值数据结构
  override def dataType: DataType = {
    StringType
  }

  // 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
  override def deterministic: Boolean = true

  // 初始化缓冲区
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Seq[String]()) // 待计算的数据存入一个序列中,以备后用,初始化为一个空序列
    buffer.update(1, Seq[Double]()) // 待计算的数据存入一个序列中,以备后用,初始化为一个空序列
    buffer.update(2, Seq[String]()) // 待计算的数据存入一个序列中,以备后用,初始化为一个空序列
    buffer.update(3, Seq[Double]()) // 待计算的数据存入一个序列中,以备后用,初始化为一个空序列
  }

  // 更新缓存的数据,输入一条数据后, 更新到缓冲区
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getSeq(0) :+ input.getString(0))
    buffer.update(1, buffer.getSeq(1) :+ input.getDouble(1))
    buffer.update(2, buffer.getSeq(2) :+ input.getString(2))
    buffer.update(3, buffer.getSeq(3) :+ input.getDouble(3))
  }

  // 合并两个聚合缓冲区并将更新后的缓冲区值存储回“buffer1” 相当于先局部聚合再全局聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0))
    buffer1.update(1, buffer1.getSeq[Double](1) ++ buffer2.getSeq[Double](1))
    buffer1.update(2, buffer1.getSeq[String](2) ++ buffer2.getSeq[String](2))
    buffer1.update(3, buffer1.getSeq[Double](3) ++ buffer2.getSeq[Double](3))
  }

  override def evaluate(buffer: Row): Any = {
    val userArrayAll = buffer.getSeq[String](0)
    val feArrayAll = buffer.getSeq[Double](1)
    val user2ArraAll = buffer.getSeq[String](2)
    val fe2ArrayAll = buffer.getSeq[Double](3)

    val user = userArrayAll(0)
    val fe = feArrayAll(0)

    val minFe: Double = (fe - 100d)
    val maxFe: Double = (fe + 100d)

    var lst = Seq[String]()
    for(i <- user2ArraAll.indices) {
      if (minFe <= fe2ArrayAll(i) && fe2ArrayAll(i) <= maxFe) {
        // user2_fe2
        lst = lst :+ (user2ArraAll(i) + "_" + fe2ArrayAll(i))
      }
    }

    // user:user2_fe2,user2_fe2...
    user + ":" + lst.mkString(",")

  }


  def getSortedIndexArray(occurrenceDateTimeArray: Array[String]): ListBuffer[Int] = {
    val lst = ListBuffer[String]()
    for(i <- occurrenceDateTimeArray.indices) {
      lst.append(occurrenceDateTimeArray(i) + "_" + i)
    }
    val lst1 = lst.sorted
    val sortedIndexArray = ListBuffer[Int]()
    for(i <- lst1.indices) {
      sortedIndexArray.append(lst1(i).split("_")(1).toInt)
    }
    sortedIndexArray
  }




}

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值