SparkSql自定义UDF和UDAF函数

30 篇文章 1 订阅

SparkSql自定义UDF和UDAF函数

package com.spark.sparksql


import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructType}
import org.junit.Test

/**sql可以使用的函数:
 *    1.一进一出:UDF函数
 *    2.多进一出:UDAF函数
 *    3.一进多出:UDTF函数(spark sql没有)
 */
class $03_udf {
  
  val spark = SparkSession.builder().master("local[4]")
    .appName("test").getOrCreate()
  import spark.implicits._
  
  
  /**
   * sparkSql中自定义udf函数步骤:
   *     1.定义一个方法/函数
   *     2.注册定义的方法/函数
   */
  @Test
  def udf:Unit={
  
    val df = List( ("1001","zhangsan",20),("000102","lisi",30),("0123","wangwu",40) )
      .toDF("id","name","age")
    //需求:正常用户id应该为8位,不满8位的id左边以0补齐
    
    //注册函数
    spark.udf.register("preId",prefixUserId _)
    //调用自定义函数
    df.selectExpr("preId(id)","name","age").show()
   
  }
  
  /**
   * 定义补齐id的方法
   * @param id
   * @return
   */
  def prefixUserId(id:String):String={
    //判断id是否有8位
    if (id.length<8){
      //不足补齐
      "0"*(8-id.length) + id
    }else{
      id
    }
  }
  
  
  /**
   * 自定义udaf函数:
   *    1.使用弱类型方式:spark3.0中标记过时
   *        1.定义一个class继承UserDefinedAggregateFunction
   *        2.重写方法
   *        3.注册
   *        4.使用
   *    2.使用强类型方法:spark3.0主推使用
   *       1.定义一个class继承Aggregator
   *       2.重写方法
   *       3.注册
   *       4。使用
   */
  @Test
  def udaf1:Unit={
  
    val df = List( ("1001","zhangsan",20),("000102","lisi",30),("0123","wangwu",40) ).toDF("id","name","age")
  
    //注册弱类型udaf函数
    spark.udf.register("myAvg",new MyAvgAgg)
    df.selectExpr("myAvg(age)").show()
    
    import org.apache.spark.sql.functions._
    //注册强类型udaf函数
    spark.udf.register("myAvg2",udaf(new MyAvgAgg2))
    df.selectExpr("myAvg2(age)").show()
    
  }
  
}

case class Buff(var sum:Int, var count:Int)

/**
 * IN: 要聚集的元素类型
 * BUF: 中间变量的类型
 * OUT: 最终结果类型
 */
class  MyAvgAgg2   extends  Aggregator[Int,Buff ,Double] {
  
  //初始化中间变量
  override def zero: Buff = Buff(0, 0)
  //每个task中聚合元素
  override def reduce(b: Buff, a: Int): Buff = {
    Buff(b.sum + a, b.count + 1)
  }
  //合并所有task的聚合结果
  override def merge(b1: Buff, b2: Buff): Buff = {
    Buff(b1.sum + b2.sum, b1.count + b2.count)
  }
  //获取最终结果值
  override def finish(reduction: Buff): Double = {
    reduction.sum.toDouble / reduction.count
  }
  //spark内部对中间结果编码使用
  //case class全部都是Product类型
  override def bufferEncoder: Encoder[Buff] = Encoders.product[Buff]
  //spark内部对最终结果编码使用
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}


/**
 * 弱类型udaf:
 *   自定义一个udaf,计算平均值
 *       计算平均值需要两个中间变量:sum、count
 */
class MyAvgAgg extends UserDefinedAggregateFunction{
  
  //输入数据类型
  override def inputSchema: StructType = new StructType().add("input",IntegerType)
  //定义计算的时候中间变量的数据类型
  override def bufferSchema: StructType = new StructType().add("sum",IntegerType)
    .add("count",IntegerType)
  //定义最终结果值类型
  override def dataType: DataType = DoubleType
  //设置稳定性,同样的输入是否输出相同
  override def deterministic: Boolean = true
  //初始化中间变量
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //初始化sum值
    buffer.update(0,0)
    //初始化count值
    buffer.update(1,0)
  }
  
  /**
   * 每个task中对数据进行聚合
   * @param buffer 中间结果
   * @param input  输入的数据
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    //更新sum值
    buffer(0) = buffer.getAs[Int](0) + input.getAs[Int](0)
    //更新count值
    buffer(1) = buffer.getAs[Int](1) + 1
  }
  
  /**
   * 合并每个task聚合结果
   * @param buffer1
   * @param buffer2
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //累加sum
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
    //累加count
    buffer1(1) = buffer1.getAs[Int](1) + buffer2.getAs[Int](1)
  }
  
  /**
   * 计算得到最终结果
   * @param buffer
   * @return
   */
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0).toDouble / buffer.getAs[Int](1)
  }
}


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值