spark中自定义udf,udaf函数

自定义函数

类型

	- UDF:一进一出
	- UDAF:多进一出

UDF

流程

spark-sql中SQL的用法
  • 1、自定义udf函数/类(类要注意需要序列化)
  • 2、注册spark.udf.register(“名称”,自定义的函数/自定义的类 _)
  • 3、调用查询方法
自定义udf函数并调用
import org.apache.spark.sql.SparkSession
import org.junit.Test

/**
  * @ClassName: MyUDFdemo
  * @Description: 将员工中id不满8位的补齐
  * @Author: kele
  * @Date: 2021/2/1 20:56
  **/

/**
  * 1、自定义udf函数/类(类要注意需要序列化)
  * 2、注册spark.udf.register("名称",自定义的函数/自定义的类 _)
  * 3、调用查询方法
  */
class MyUDFdemo extends Serializable{

  @Test
  def emp_info={

    val spark = SparkSession.builder().master("local[4]").appName("UDFdemo").getOrCreate()

    import spark.implicits._    //rddtoDF的隐式转换

    val rdd1 = spark.sparkContext.parallelize(List(
      ("00123","zhangsan"),
      ("256","lisi"),
      ("0135","wangwu"),
      ("000368","qianqi"),
      ("00378","zhaoliu")
    ))

    val df = rdd1.toDF("id","name")

    /**
      * 方式一:通过sql的方式查询 自定义函数
      *
      */

//    df.createOrReplaceTempView("user")
//    spark.udf.register("fullId",fullUserId)
//    spark.sql("""select fullId(id) from user """).show()

        /**
          * 自定义类,需要序列化
          *
          */
    df.createOrReplaceTempView("user")
    spark.udf.register("fullId2",fullUserIdclass _)
    spark.sql("""select fullId2(id) from user """).show()

    /**
      * 方式二:selectExpr的方式查找
      */
    df.selectExpr("fullId2(id) id").show()

  }

  //自定义udf函数
  val fullUserId = (id : String)=>{
    s"${"0" *(8-id.length)}${id}"
  }

  //自定义udf类
  def fullUserIdclass(id:String) ={
    s"${"0" *(8-id.length)}${id}"
  }
}
spark-sql中DataFram中的用法

在spark的DataFram的udf方法和spark sql的名字相同,但是属于不同的类,

import org.apache.spark.sql.functions._

//方法一:注册自定义函数(通过匿名函数)
val strLen = udf((str: String) => str.length())

//方法二:注册自定义函数(通过实名函数)
val udf_isAdult = udf(isAdult _)

UDAF

UDAF弱类型实现

总体流程

  • 1、继承UserDefinedAggregateFunction( 没有泛型)
  • 2、重写方法
    - 1、指定带统计列表的类型
    - 2、指定中间变量的类型
    - 3、指定函数的返回类型
    - 4、设置稳定性
    - 5、初始化中间变量的值
    - 6、求在一个task中的计算过程
    - 7、求在分区间的计算过程
    - 8、函数的返回值
  • 3、注册spark.udf.register,为其绑定一个名字
自定义UDAF弱类型
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructType}
 
 /**
   * @ClassName: MyUDAF
   * @Description:
   * @Author: kele
   * @Date: 2021/2/1 16:03
   **/
 class MyUDAF extends UserDefinedAggregateFunction{
 
   /**
     * 指定待统计的数据类型
     * @return 返回值类型是StructType类型,
     */
   override def inputSchema: StructType = new StructType().add("age",IntegerType)
 
   /**
     * 这里是求平均值,需要sum,和num,因此需要两个中间变量
     * 指定中间变量的类型,数据进入是是一个个进
     * @return
     */
   override def bufferSchema: StructType = new StructType().add("sum",IntegerType)
     .add("num",IntegerType)
   /**
     * 函数的返回类型
     * @return
     */
   override def dataType: DataType = DoubleType
 
   /**
     * 稳定性,同一组数据输入是否返回相同的值
     * @return
     */
   override def deterministic: Boolean = true
 
   /**
     * 初始化buffer的值
     * @param buffer
     */
   override def initialize(buffer: MutableAggregationBuffer): Unit = {
 
     buffer.update(0,0)
     buffer.update(1,0)
   }
 
   /**
     * 在一个task中的计算过程
     *   sum将age不断累加
     *   count+1
     * @param buffer
     * @param input
     */
   override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
 
     buffer.update(0,buffer.getAs[Int](0)+input.getAs[Int](0))
     buffer.update(1,buffer.getAs[Int](1)+1)
   }
 
   /**
     * 分区间的计算方式
     * @param buffer1
     * @param buffer2
     */
   override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
 
     buffer1.update(0,buffer1.getAs[Int](0)+buffer2.getAs[Int](0))
     buffer1.update(1,buffer1.getAs[Int](1)+buffer2.getAs[Int](1))
 
   }
 
   /**
     * 返回值
     * @param buffer
     * @return
     */
   override def evaluate(buffer: Row): Any = buffer.getAs[Int](0).toDouble/buffer.getAs[Int](1)
 }
调用过程
  /**
    * 调用弱类型
    */
  @Test
  def avg_Age={

    val spark = SparkSession.builder()
                                    .master("local[4]")
                                        .appName("avg_age")
                                              .getOrCreate()

    val rdd = spark.sparkContext.parallelize(List(
      ("zhangsan",20,"开发部"),
      ("wanwu",25,"产品部"),
      ("aa",26,"开发部"),
      ("lisi",40,"开发部"),
      ("bb",30,"产品部"),
      ("cc",28,"产品部")
    ))

    import spark.implicits._

    val df = rdd.toDF("name","age","dept")


    df.createOrReplaceTempView("user")

    spark.udf.register("myavg",new MyUDAF)

    spark.sql(
      """
        |select myavg(age) from user group by dept
      """.stripMargin).show()
  }

UDAF强类型实现过程

  • 1、自定义class继承Aggregator[统计的列的类型、中间变量类型,输出结果类型]

  • 2、重写方法

    • 1、初始化中间变量
    • 2、每一个task中的统计过程
    • 3、分区间计算过程
    • 4、计算最终结果并返回
    • 5、编码中间变量的类型,个人认为是为了保证中间数据传输
      注意样例类的父类是product
  • 3、注册spark.udf.register(函数名,udaf(自定义udaf对象))

      		- import org.apache.spark.sql.functions._         //必须调用该隐式转换,否则无法导入
    
自定义强类型
 package com.atguigu.day05
 
 import org.apache.spark.sql.{Encoder, Encoders}
 import org.apache.spark.sql.expressions.Aggregator
 
 /**
   * @ClassName: MyUDAF2
   * @Description: 强类型自定义类,Aggregator可以自定义泛型[输入类型,中间变量,输出类型]
   * @Author: kele
   * @Date: 2021/2/1 16:05
   **/
 
 /**
   * 如果需要多个中间变量,可以考虑使用样例类
   *
   */
 
 case class InterVari(var sum:Int,var count:Int)
 class MyUDAFStrong extends Aggregator[Int,InterVari,Double]{
 
   /**
     * 初始化中间变量
     * @return
     */
   override def zero: InterVari = InterVari(0,0)
 
   /**
     * 每一个task中的统计过程
     * @param b
     * @param a
     * @return
     */
   override def reduce(b: InterVari, a: Int): InterVari = {
 
     b.sum = b.sum+a
     b.count = b.count+1
     b
   }
 
   /**
     * 分区间计算过程
     * @param b1
     * @param b2
     * @return
     */
   override def merge(b1: InterVari, b2: InterVari): InterVari = {
     b1.sum = b1.sum + b2.sum
     b1.count = b1.count + b2.count
     b1
   }
 
   /**
     * 最终结果返回
     * @param reduction
     * @return
     */
   override def finish(reduction: InterVari): Double = reduction.sum.toDouble/reduction.count
 
   /**
     * 编码中间变量的类型,个人认为是为了保证中间数据传输
     * @return  样例类的父类是product
     */
   override def bufferEncoder: Encoder[InterVari] = Encoders.product
 
   /**
     * 编码结果值的类型,个人认为是为了保证中间数据传输
     * @return
     */
   override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
 }
调用过程
  /**
    * 使用弱类型
    */
  @Test
  def avg_Age={

    val spark = SparkSession.builder()
                                    .master("local[4]")
                                        .appName("avg_age")
                                              .getOrCreate()

    val rdd = spark.sparkContext.parallelize(List(
      ("zhangsan",20,"开发部"),
      ("wanwu",25,"产品部"),
      ("aa",26,"开发部"),
      ("lisi",40,"开发部"),
      ("bb",30,"产品部"),
      ("cc",28,"产品部")
    ))

    import spark.implicits._

    val df = rdd.toDF("name","age","dept")


    df.createOrReplaceTempView("user")

    spark.udf.register("myavg",new MyUDAF)

    spark.sql(
      """
        |select myavg(age) from user group by dept
      """.stripMargin).show()
  }

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值