scala实现UDF及UDAF案例

首先初始化一个Dataset

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.{Dataset, Row, SparkSession, types}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.types._

import java.text.SimpleDateFormat
case class Order(detail_id:String,order_id:String,product_id:String,num:Int,amt:Double,order_time:String)

object udf_udaf_udtf {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("udf-udaf-udtf")
                            .master("local[*]").enableHiveSupport().getOrCreate()
    import spark.implicits._

    val arr0=Array(("d001","o001","jacket001",1,200.2,"2022-07-20 10:25:19"),
                   ("d002","o001","shoe001",1,305.3,"2022-07-20 10:25:21"),
                   ("d003","o002","tshirt001",2,158.8,"2022-07-21 10:26:01"),
                   ("d004","o003","skirt001",1,88.6,"2022-07-21 18:26:32"),
                   ("d005","o003","skirt002",1,62.1,"2022-07-21 18:27:02"),
                   ("d006","o004","shoe001",1,305.3,"2022-07-22 12:27:33"),
                   ("d007","o005","jacket001",2,400.4,"2022-07-25 15:19:58"),
                   ("d008","o005","shoe001",2,610.6,"2022-07-25 15:19:59"),
                   ("d009","o006","skirt002",2,124.2,"2022-07-27 13:09:55"),
                   ("d010","o007","skirt001",1,88.6,"2022-07-27 14:06:55"))
    //创建dataset
    val orderDS = spark.sparkContext.parallelize(arr0).map({
                  line=>Order(line._1.toString,line._2.toString,line._3.toString,
                              line._4.toInt,line._5.toDouble,line._6.toString) }).toDS()

UDF注册及使用:

//定一个udf转换函数,标准日期时间格式转时间戳
    def datetimeTrans(datetime:String):String = {
      val fm = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
      val dt = fm.parse(datetime)
      val tim: Long = dt.getTime
      tim.toString
    }

    /*==============================UDF的注册和使用方式======================================*/
    spark.udf.register("datetimeTrans",datetimeTrans(_:String))
    /*另一种注册udf的方式也可采用如:
    * spark.udf.register("hob_num",(s:String)=>s.split(",").size)
    * */
    orderDS.createOrReplaceTempView("sales")
    spark.sql("select detail_id,order_id,datetimeTrans(order_time) as tstamp from sales limit 10").show()

运行结果:

+---------+--------+-------------+
|detail_id|order_id|       tstamp|
+---------+--------+-------------+
|     d001|    o001|1658283919000|
|     d002|    o001|1658283921000|
|     d003|    o002|1658370361000|
|     d004|    o003|1658399192000|
|     d005|    o003|1658399222000|
|     d006|    o004|1658464053000|
|     d007|    o005|1658733598000|
|     d008|    o005|1658733599000|
|     d009|    o006|1658898595000|
|     d010|    o007|1658902015000|
+---------+--------+-------------+

UDAF实现多行转一行的聚合效果,实现过程需要对官方的抽象类UserDefinedAggregateFunction进行详细的逻辑定义,如下:

/* 对官方抽象类UserDefinedAggregateFunction进行重写,定义udaf的逻辑
 * 求均价sum(amt)/sum(num)  及 聚合维中出现的最大数量max(num)
 **/
class MySumFunction extends UserDefinedAggregateFunction{   //自定义函数
  //step1:定义输入数据的schema
  override def inputSchema: StructType = {
    //new StructType().add("age",LongType)
    StructType(Seq(
      StructField("num",IntegerType),  //数量num 整型
      StructField("amt",DoubleType)    //金额amt double
    ))
  }
  //step2:定义缓存bufferSchema
  //根据计算需求,要缓存sum(数量)  sum(金额) 和 当前最大数量
  override def bufferSchema: StructType = {
    new StructType().add("sum_num",IntegerType)
                    .add("sum_amt",DoubleType)
                    .add("max_num",IntegerType)
  }
  //step3:定义聚合函数返回的数据结构
  //override def dataType: DataType = DoubleType
  override def dataType: DataType = ArrayType(DoubleType)

  //step4:初始化缓存
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = 0.00
    buffer(2) = 0
  }

  //step5:定义每进一条数据,对缓存的计算规则
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0)+input.getInt(0)                 //num累加计算
    buffer(1) = buffer.getDouble(1)+input.getDouble(1)           //amt累加计算
    buffer(2) = (if (buffer.getInt(2)>=input.getInt(0)) buffer.getInt(2) else input.getInt(0))  //找寻最大值逻辑
  }

  //step6:各缓存合并的逻辑
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0)+buffer2.getInt(0)        //num累计
    buffer1(1) = buffer1.getDouble(1)+buffer2.getDouble(1)  //amt累计
    buffer1(2) = (if (buffer1.getInt(2)>=buffer2.getInt(2)) buffer1.getInt(2) else buffer2.getInt(2)) //最大值逻辑
  }

  //step7:计算最终结果
  override def evaluate(buffer: Row): Seq[Double] = {
    //(buffer.getLong(0)/buffer.getLong(1)).toDouble      //返回单个数据的形式
    Seq(buffer.getDouble(1)/buffer.getInt(0),  // 销售均价
        buffer.getInt(2))   //最大销售数量
  }

  //聚合函数 相同的输入是否要相同的输出,聚合函数是否幂等操作
  override def deterministic: Boolean = true
}

再定义注册及使用过程:

/*==============================UDAF的注册和使用方式====================================*/
    spark.udf.register("mySum",new MySumFunction)
    //sparkSQL的方式使用udaf
    spark.sql(
      """
        |select a.product_id as product_id,
        |       a.ms[0] as avg_price, --商品均价
        |       a.ms[1] as max_num    --商品单次最大售卖数量
        |from
        |(select product_id,mySum(num,amt) as ms
        |from sales group by product_id ) a
        |limit 100
        |""".stripMargin).show()
    /*
    * */
    //dataset方式处理数据,并将集合拆分
    val mySum = new MySumFunction
    val ds1 = orderDS.groupBy("product_id").agg(mySum(col("num"),col("amt")).as("ms"))
    val ds2 = ds1.withColumn("product_id",col("product_id"))
                 .withColumn("avg_price",$"ms".getItem(0))
                 .withColumn("max_num",$"ms".getItem(1))
    ds2.show()

两种处理方式的结果:

sparksql UDAF结果:
+----------+------------------+-------+
|product_id|         avg_price|max_num|
+----------+------------------+-------+
|   shoe001|             305.3|    2.0|
|  skirt002|              62.1|    2.0|
|  skirt001|              88.6|    1.0|
| tshirt001|              79.4|    2.0|
| jacket001|200.19999999999996|    2.0|
+----------+------------------+-------+
ds2结果:
+----------+--------------------+------------------+-------+
|product_id|                  ms|         avg_price|max_num|
+----------+--------------------+------------------+-------+
|   shoe001|        [305.3, 2.0]|             305.3|    2.0|
  |  skirt002|         [62.1, 2.0]|              62.1|    2.0|
  |  skirt001|         [88.6, 1.0]|              88.6|    1.0|
  | tshirt001|         [79.4, 2.0]|              79.4|    2.0|
  | jacket001|[200.199999999999...|200.19999999999996|    2.0|
  +----------+--------------------+------------------+-------+

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值