udf和udaf

UDF:计算单词的长度

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object udfDemo {
  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder().appName("udfDemo").master("local[2]").getOrCreate()
    val sc = session.sparkContext
    //构造模拟数据
    val names=Array("Leo","Marry","Jack","Tom")
    val namesRdd = sc.parallelize(names,4)
    val nameRowRdd = namesRdd.map(name=>Row(name))
    val structType = StructType(Array(StructField("name",StringType,true)))
    //创建dataframe,创建是由SparkSession创建的
    val namesDF = session.createDataFrame(nameRowRdd,structType)
    //注册临时表
    namesDF.createTempView("names")
    //定义和注册自定义函数
    //定义函数:自己写匿名函数
    //注册函数:udf.register,注册函数使用的是SparkSession
    session.udf.register("strLen",(str:String)=>str.length)
    //接下来只用的自定义函数
    session.sql("select name,strLen(name) from names").collect().foreach(println)
  }
}

udaf:统计单词出现的次数

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}

object udfDemo1 {
  def main(args: Array[String]): Unit = {
    val sp = SparkSession.builder().appName("udfDemo").master("local[*]").getOrCreate()
    val sparkContext = sp.sparkContext
    val rdd0 = sparkContext.parallelize(Array("tom","jerry","mary","mary","tom")).map(name=>Row(name))
    val structType = StructType(Array(StructField("name",StringType,true)))//类型和是否为空
    val namedf = sp.createDataFrame(rdd0,structType)

    //注册表
    namedf.createTempView("v_name")
    //注册自定义的函数
    sp.udf.register("myagg",new UDFA)
    //使用
    val res1= sp.sql("select name,myagg(name) from v_name group by name").collect().foreach(println)
  }
}
class UDFA extends UserDefinedAggregateFunction{
  //指的是,输入数据类型
  override def inputSchema: StructType = {
    StructType(Array(StructField("str",StringType,true)))
  }

  //指的是,中间进行聚合时,所处理数据的类型
  override def bufferSchema: StructType = {
    StructType(Array(StructField("count",IntegerType,true)))
  }
  //最后聚合结果的返回类型--》函数返回值的类型
  override def dataType: DataType = {
    IntegerType
  }
  //数据的一致性,一般为true
  override def deterministic: Boolean = true

  //聚合函数的初始化值,局部聚合使用,为每一个分组的数据执行初始化操作
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0
  }
  
  // 在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
  // 本地的聚合操作,相当于Hadoop MapReduce模型中的Combiner
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0)=buffer.getAs[Int](0)+1
  }
  
  // 由于Spark是分布式的,所以一个分组的数据,可能会在不同的节点上进行局部聚合,就是update
  // 但是,最后一个分组,在各个节点上的聚合值,要进行merge,也就是合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0)=buffer1.getAs[Int](0)+buffer2.getAs[Int](0)
  }
  
  // 最后返回一个最终的聚合值
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }
}

 

 

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值