SparkSQL 使用UDAF实现自定义聚合函数

5 篇文章 0 订阅
4 篇文章 0 订阅

一、介绍

Spark SQL中自定义函数包括UDF和UDAF

自定义函数

          UDF:一进一出

          UDAF:多进一出

二、UDAF函数

UDAF:User Defined Aggregate Function。用户自定义聚合函数。是Spark 1.5.x引入的最新特性。 *

UDF:其实更多的是针对单行输入,返回一个输出 * 这里的UDAF,则可以针对多行输入,进行聚合计算,返回一个输出,功能更加强大。

首先创建class继承UserDefinedAggregateFunction类

package SparkSQL

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class StringCount extends UserDefinedAggregateFunction{
  //指的是输入数据的类型
  override def inputSchema: StructType = {
    StructType(Array(StructField("str",StringType,true)))
  }
  //bufferSchema指的是中间进行聚合时,所处理的数据类型
  override def bufferSchema: StructType = {
    StructType(Array(StructField("count",IntegerType,true)))
  }
  //dataType指的是函数返回值的类型
  override def dataType: DataType = {
    IntegerType
  }

  override def deterministic: Boolean = {
    true
  }
  //为每个分组的数据执行初始化操作
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0
  }
  //指的是,每个分组有新的值进来的时候,如何进行分组对应的聚合值的计算
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0)=buffer.getAs[Int](0)+1
  }
  //由于spark是分布式的。所以一个分组的数据,可能会在不同的节点上进行局部聚合,就是update
  //但是,最后一个分组会在节点上聚合值,要进行merge,也就是合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0)=buffer1.getAs[Int](0)+buffer2.getAs[Int](0)
  }
  //最后指的是,一个分组聚合之,如何通过中间的缓存聚合之,最后返回一个最终的聚合值
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }
}

 接下来进行测试

package SparkSQL

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}

object UDAF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("UDF")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    //模拟构造数据
    val names = Array("leo","Marry","Jack","Tom","Tom","Tom","leo","leo")
    val nameRDD=sc.parallelize(names,5)
    val namesRowRDD=nameRDD.map{name=>Row(name)}
    val structType = StructType(Array(StructField("name",StringType,true)))
    val namesDF=sqlContext.createDataFrame(namesRowRDD,structType)
    //注册一张零时表
    namesDF.registerTempTable("names")
    //定义和注册自定义函数
    sqlContext.udf.register("strCount",new StringCount)
    //使用自定义函数
    sqlContext.sql("select name,strCount(name) as a from names group by name order by a desc").collect().foreach(println)
//  sqlContext.sql("select name,count(*) from names group by name").collect().foreach(println)
  }
} 

运行结果:

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值