SparkSQL(5):UDF和UDAF

68 篇文章 0 订阅
16 篇文章 1 订阅

1.二者区别

UDF:用户自定义函数,一输入一输出

UDAF:用户自定义聚合函数,多输入一输出

2.实现代码

(1)UDAF代码:

package _0728sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * 
  */
object Avg_UDAF extends UserDefinedAggregateFunction{
  //(a)
  override def inputSchema: StructType = {
    /**
      * 给定UDAF函数的输入参数类型(schema)
      * iv代表input value
      */
    StructType(Array(
      StructField("iv",DoubleType)
    ))
  }
  //(b)
  override def bufferSchema: StructType = {
    //给定缓存数据的数据类型 avg = totalValue / totalCount
    //tv:total value
    //tc:total count
    StructType(Array(
      StructField("tv",DoubleType),
      StructField("tc",IntegerType)
    ))
  }
  //(c)
  override def dataType: DataType = {
    //给定返回的数据类型
    DoubleType
  }
  //(d)
  override def deterministic: Boolean = {
    //给定多次运行是否允许返回结果不一致(模糊查询) true表示不允许
    //一般都为true
    true
  }
  //(e)
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //缓存数据的初始值
    buffer.update(0,0.0)
    buffer.update(1,0)

  }
  //(f)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    //对于每一条输入数据(当前分组的),更新buffer中的值
    //1、获取输入数据
    val  iv = input.getDouble(0)
    //2、获取缓存区数据
    val tv = buffer.getDouble(0)
    val tc = buffer.getInt(1)
    //3、更新缓存区数据
    buffer.update(0,tv + iv)
    buffer.update(1,tc + 1)
  }
  //(g)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //当两个分区的结果需要进行合并的时候,会调用该merge方法
    //1、获取buffer1的数据
    val tv1 = buffer1.getDouble(0)
    val tc1 = buffer1.getInt(1)
    //2、获取buffer2的数据
    val tv2 = buffer2.getDouble(0)
    val tc2 = buffer2.getInt(1)
    /*
     3、然后把数据更新到buffer1当中去,不能更新到buffer2
       因为MutableAggregationBuffer这个数据类型才是可以更新的数据类型实现了update方法
     */
    buffer1.update(0,tv1+tv2)
    buffer1.update(1,tc1+tc2)
  }
  //(h)
  override def evaluate(buffer: Row): Any = {
    val tv = buffer.getDouble(0)
    val tc = buffer.getInt(1)
    tv/tc
  }
}

(2)总代码

package _0728sql

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SparkSession}
//import _0728sql.Avg_UDAF
/**
  *
  */
object UDFandUDAF extends App{

  /**
    *
    */
  val conf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("UDFandUDAF")
  //这个方法是一个锁的机制,通过这个方法可以保证只有一个上下文
  val sc = SparkContext.getOrCreate(conf)
  //如果不需要用hive就不要用hivecontext,使用sqlcontext就可以了
  val sqlContext = new SQLContext(sc)

  //1.UDF
// udf 保留小数点后两位
// format_double是函数名称,后面是匿名函数
  sqlContext.udf.register("format_double",(value:Double)=>{
    import java.math.BigDecimal
    val bd=new BigDecimal(value)
    bd.setScale(2,BigDecimal.ROUND_HALF_UP).doubleValue()
  })

import sqlContext.implicits._
  sc.parallelize(Array(
    (1, 1234),
    (1, 45212),
    (1, 22125),
    (1, 12521),
    (1, 12352),
    (2, 52352),
    (2, 2232),
    (2, 12521),
    (2, 12323),
    (3, 2253),
    (3, 2233),
    (3, 22558),
    (4, 252),
    (4, 235),
    (5, 523)
  )).toDF("id", "sal").registerTempTable("tmp_emp")

  sqlContext.sql(
    """
      			|select
      			|id,AVG(sal) as sal1,
      			|format_double(AVG(sal)) as sal3
      			|from tmp_emp
      			|group by id
    		""".stripMargin).show

  //2.UDAF
  sqlContext.udf.register("self_avg",Avg_UDAF)
  sqlContext.sql(
    """
      			|select
      			|id,AVG(sal) as sal1,
      			|format_double(AVG(sal)) as sal3,
            |format_double(self_avg(sal)) as sal4
      			|from tmp_emp
      			|group by id
    		""".stripMargin).show

}

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值