Spark自定义UDF和UDAF函数

前言:
最然spark提供的算子和函数非常丰富,但是对于一些特殊的业务需求,还是自定义函数比较好用,自定义函数一般有UDF和UDAF,UDF就是一对一的函数,UDAF是一对多函数

一:自定义UDF函数

步骤:
1.自定义方法–满足需求
2.注册方法
3.在sql语句中使用函数

//需求:正常情况下员工的工号是8位,现在工号长度不够,需要使用0来补全
object demo03 {
  private val spark = SparkSession.builder().master("local[6]").appName("aggregation").getOrCreate()
  import spark.implicits._

  //创建数据
  private val rdd: RDD[(String, String)] = spark.sparkContext.parallelize(Seq( ("010","张三"),("020","李四"),( "003","王五")))
  //把数据转成表的形式
  rdd.toDF("id","name").createTempView("employee")

  //自定义UDF函数,该函数的功能是补全员工id的位数,所以参数是id,返回值类型是String
  def fillID(id:String): String ={
     "0"*(8-id.length)+id
  }

  def main(args: Array[String]): Unit = {
    //注册自定义的UDF函数,参数1:函数的名称 , 参数2:自定义的方法转成函数的形式
    spark.udf.register("fillID", fillID _)

    //验证自定义函数
    spark.sql("select fillID(ID),name from employee").show()
  }
}

结果:

+--------------+----+
|UDF:fillID(ID)|name|
+--------------+----+
|      00000010|  张三|
|      00000020|  李四|
|      00000003|  王五|
+--------------+----+

二:自定义UDAF函数

步骤:
1.创建类继承UserDefinedAggregateFunction类
2.重写方法
3.测试自定义函数

//需求:计算员工平均工资
object demo04 {
  //1.创建类去继承
  class UdafFunction extends UserDefinedAggregateFunction{
    //2.重写8个方法

    /**
      * 聚合函数输入参数的类型--返回值是StructType类型
      * @return
      */
    override def inputSchema: StructType = new StructType().add("input",IntegerType)

    /**
      * 缓冲区:在计算过程中需要用到的中间变量的类型
      *  需要用到两个中间变量: 一个为sum【用来统计价格的总和】  一个为total【用来记录商品的个数】
      * @return
      */
    override def bufferSchema: StructType = {
      new StructType().add("sum",IntegerType).add("total",IntegerType)
    }

    /**
      * 指明返回值类型
      * @return
      */
    override def dataType: DataType = DoubleType

    /**
      * 是否保存数据的一致性,一般设为true
      * @return
      */
    override def deterministic: Boolean = true

    /**
      * 初始化缓冲区,设置sum=0,total=0
      * @param buffer
      */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      //初始化sum
      buffer(0)=0
      //初始化total
      buffer(1)=0
    }

    /**
      * 更新缓冲区的值
      *    每进来一条数据,sum需要累加,total需要+1
      * @param buffer
      * @param input
      */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      //sum进行累加
      buffer(0) = buffer.getAs[Int](0)+input.getAs[Int](0)
      //total自增1
      buffer(1) = buffer.getAs[Int](1)+1

    }

    /**
      * 合并缓冲区
      *  将所有的缓冲区的sum与total值继续累加
      * @param buffer1
      * @param buffer2
      */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      //累加sum
      buffer1(0) = buffer1.getAs[Int](0)+buffer2.getAs[Int](0)
     //累加total
      buffer1(1) = buffer1.getAs[Int](1)+buffer2.getAs[Int](1)
    }

    /**
      * 返回最终结果
      * @param buffer
      * @return
      */
    override def evaluate(buffer: Row): Any = {
      buffer.getAs[Int](0).toDouble/buffer.getAs[Int](1)
    }
  }

  def main(args: Array[String]): Unit = {
    //创建入口
    val spark = SparkSession.builder().master("local[6]").appName("aggregation").getOrCreate()
    //导入隐式转换
    import spark.implicits._

    //创建数据
    val rdd: RDD[(String, Int)] = spark.sparkContext.parallelize(Seq(("huawei",1000),("thinkpad",2000),("redmi",600)))
    //创建临时表
    rdd.toDF("name","price").createTempView("product")

    //注册自定义UDAF函数
    /**
      * 源码
      *  def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction
      *  参数1:名称,随便定义,用于使用函数的时候
      *  参数2:UserDefinedAggregateFunction类型 , 是自定义的类的父类,所以new UdafFunction即可
      */
    spark.udf.register("UdafFunction",new UdafFunction)

    //使用自定义UDAF函数
    spark.sql("select UdafFunction(price) from product").show()

  }
}

结果:

+-------------------+
|udaffunction(price)|
+-------------------+
|             1200.0|
+-------------------+

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值