Spark UDAF

使用很少,绝大部分udaf可以用更简洁的 udf 代替

import org.apache.log4j.Logger
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataTypes, StringType, StructField, StructType}

object UDAFTest {

  private val logger = Logger.getLogger(this.getClass)

  val spark: SparkSession = SparkSession
    .builder()
    .appName("local-test")
    .master("local[4]")
    //.enableHiveSupport()
    //.config("spark.shuffle.service.enabled", true)
    //.config("spark.driver.maxResultSize", "4G")
    //.config("spark.sql.parquet.writeLegacyFormat", true)
    .getOrCreate()

  import spark.implicits._

  spark.sparkContext.setLogLevel("warn")


  def main(args: Array[String]): Unit = {
    val t1 = System.currentTimeMillis()
    val source = Array(
      ("11111111", "2021-03-29 00:00:00", "3"),
      ("11111111", "2021-03-29 00:01:00", "5"),
      ("11111111", "2021-03-29 00:02:00", "7"),
      ("11111111", "2021-03-29 00:03:00", "9"),
      ("11111111", "2021-03-29 00:04:00", "11"),
      ("22222222", "2021-03-29 00:00:00", "3"),
      ("22222222", "2021-03-29 00:01:00", "5"),
      ("22222222", "", "7"),
      ("22222222", "2021-03-29 00:03:00", "9"),
      ("22222222", "2021-03-29 00:04:00", "11")
    )

    val rawRDD = spark.sparkContext.parallelize(source)
    val rawDF = rawRDD.map(x => {
      (x._1, x._2, x._3)
    }).toDF("use_id", "occur_time", "value")
      .selectExpr(
        "use_id",
        "occur_time",
        "cast(value as double) as value"
      )

    rawDF.printSchema()
    rawDF.show(false)

    testMaxWhen(rawDF)

    val t2 = System.currentTimeMillis()
    logger.warn("======== run time is: " + ((t2 - t1) / 1000) )
    spark.stop()

  }

  private def testMaxWhen(df: DataFrame) = {
    spark.udf.register("max_when", new MaxWhenUdaf)


    val tmpDF = df.groupBy("use_id")
      .agg(
        callUDF("max_when", $"value", lit(11)).as("max_when_test")
      )

    tmpDF.printSchema()
    tmpDF.show(false)


  }


  class MaxWhenUdaf extends UserDefinedAggregateFunction {
    // 输入参数的数据类型
    override def inputSchema = {

      //DataTypes.createStructField("value_v", DataTypes.StringType, true)

      //StructType(Seq(StructField("value_v", StringType)))

      //StructType(StructField("value_v", StringType) :: Nil)

      new StructType()
        .add("value_v", "double")
        .add("value_tag", "double")

    }

    // buffer中的数据类型
    override def bufferSchema = {
      new StructType()
        .add("buffer_v", "double")

    }

    // 返回值的类型
    override def dataType = {
      DataTypes.StringType
    }

    // 确保一致性 一般用true,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果
    override def deterministic = {
      true
    }

    // 初始化
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = Double.MinValue
    }

    /**
     *
     * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
     * buffer.getAs[Double](0)获取的是上一次聚合后的值
     * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
     * 大聚和发生在reduce端.
     * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
     */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val calValue = input.getAs[Double](0)
      val calFlag = input.getAs[Double](1)
      if(calValue != null && calFlag != null && !"".equals(calValue) && !"".equals(calFlag)){
        if(calValue != calFlag
          && calValue> buffer.getAs[Double](0) ) {
          buffer.update(0, calValue)
        }
      }
    }

    /**
     * 合并其他部分结果
     * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
     * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
     * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
     * 也可以是一个节点里面的多个executor合并
     */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      if(buffer1.getDouble(0) > buffer2.getDouble(0) ){
        buffer1.update(0, buffer1.getDouble(0))
      }else{
        buffer1.update(0, buffer2.getDouble(0))
      }

    }

    // 计算逻辑
    override def evaluate(buffer: Row) = {
      buffer.getDouble(0).toString
    }
  }

}

另外这几篇也不错:

https://zhuanlan.zhihu.com/p/25587189

https://blog.csdn.net/kwu_ganymede/article/details/50462020

https://www.jianshu.com/p/ca3ce8baeffb

https://blog.csdn.net/fengfengchen95/article/details/88681780

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值