Spark SQL 中的UDF、UDAF、UDTF

一、UDF

UDF(User-defined functions)用户自定义函数,简单说就是输入一行输出一行的自定义算子。(一对一)
数据文件:hobbies.txt,第一列为姓名,其他为兴趣爱好

alice,jogging&Coding&cooking
lina,traveldance&cooking

自定义UDF,实现的是计算每个人的兴趣爱好个数

// 样例类
case class Hobbies(name:String,hobbies:String)
object UDFDemo {
  def main(args: Array[String]): Unit = {
  	// 获取SparkSession对象
    val spark: SparkSession = SparkSession.builder()
      .appName("udfDemo")
      .master("local[*]")
      .getOrCreate()
	// 获取SparkContext对象
    val sc:SparkContext = spark.sparkContext
    import spark.implicits._
	// 读取文件
    val rdd1: RDD[String] = sc.textFile("in/hobbies.txt")
    // 将姓名与爱好以逗号分隔,创建成样例类后转成DataFrame
    val df: DataFrame = rdd1.map(_.split(","))
      .map(x => Hobbies(x(0), x(1)))
      .toDF()
	// 注册临时表
    df.registerTempTable("hobbies")
    // 注册udf,名字为hoby_num,功能用匿名函数代替
    spark.udf.register("hoby_num",
      (x:String)=>x.split("&").length)
	// 查询
    val frame: DataFrame = spark.sql("select name,hobbies,hoby_num(hobbies) from hobbies")
    frame.show()
  }
}

在这里插入图片描述

二、UDAF

UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是:普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。(多对一)
数据文件:user.json

{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}

自定UDAF,实现计算平均年龄

object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("udafDemo")
      .getOrCreate()

    val df: DataFrame = spark.read.json("in/user.json")
    spark.udf.register("ageAvg",new MyAgeAvgFunction)
	// 创建临时视图
    df.createTempView("userInfo")
	// 查询
    val frame: DataFrame = spark.sql("select sex,ageAvg(age) from userInfo group by sex")
    frame.printSchema()
    frame.show()
  }

}
// AgeAvgFunction继承UserDefinedAggregateFunction,需要重写8个方法
class AgeAvgFunction extends UserDefinedAggregateFunction{

  // 指定聚合函数的输入数据类型
  override def inputSchema: StructType = {
  	// age为要聚合的列,LongType为类型
    new StructType().add("age",LongType)
    //也可以写成下面这种形式
    // StructType(StructField("age",LongType)::Nil)
  }

  // 指定缓冲区的数据结构
  override def bufferSchema: StructType = {
  	// 此处可以这样理解:缓冲区会保存两个数据,sum是用来记录年龄总和,count是用来记录总人数
    new StructType().add("sum",LongType).add("count",LongType)
    //    StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)
  }
  // 指定集合函数输出数据的类型
  override def dataType: DataType = DoubleType
  // 聚合函数是否是幂相等,即相同输入数据是否总能得到相同输出数据
  override def deterministic: Boolean = true
  // 初始化缓冲区的初始值,可根据需要自行设定
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0L
    buffer(1)=0L
  }
  // 可理解为单个buffer内部的计算,即一条数据传递到一个buffer内后,它需要把之前的年龄与此条数据年龄相加,然后数量加1
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }
  // 合并多个buffer计算的结果,类似不同分区结果合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // 计算最终结果,总年龄/总人数
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble/buffer.getLong(1)
  }
}

三、UDTF

UDTF(User-Defined Table-Generating Functions),用户自定义生成函数。它就是输入一行输出多行的自定义算子,可输出多行多列,又被称为 “表生成函数”。(一对多)

object SparkUDTFDemo {
  def main(args: Array[String]): Unit = {
    // 创建SparkSession
    val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("sqlDemo")
      .enableHiveSupport()  // //启用hive
      .getOrCreate
    val sc: SparkContext = spark.sparkContext
    import spark.implicits._
    
    val lines: RDD[String] = sc.textFile("in/udtf.txt")
    // 将数据处理并转换成DataFrame
    val stuDF: DataFrame = lines.map(_.split(","))
      .filter(x => x(1).equals("ls"))
      .map(x => (x(0), x(1), x(2)))
      .toDF("id", "name", "class")
    // 创建或替换临时视图
    stuDF.createOrReplaceTempView("student")
    // 这里需要注意,如果编写的UDTF类有包名,as 后面需要将表名写上
    spark.sql("create temporary function myFunc as 'sql.myUDTF'")
    // 在spark sql 中使用UDTF查询
    val resultDF: DataFrame = spark.sql("select myFunc(class) from student")
    // 查看结果
    resultDF.printSchema()
    resultDF.show()
  }
}

// 继承GenericUDTF类
class myUDTF extends GenericUDTF{
  // 该函数的作用:①输入参数校验,只能传递一个参数 ②指定输出的字段名和字段类型
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    // 只能有一个参数,若多于1个,则抛异常
    if(argOIs.length!=1){
      throw new UDFArgumentException("只能传递一个参数")
    }
    // 用于验证参数的类型
    if(argOIs(0).getCategory!=ObjectInspector.Category.PRIMITIVE){
      throw new UDFArgumentException("参数类型不匹配")
    }
    
    //初始化表结构
    //创建数组列表存储表字段
    val fieldNames = new util.ArrayList[String]()
    val fieldsOIs = new util.ArrayList[ObjectInspector]()
    // 输出字段的名称
    fieldNames.add("hobbies")
    // 这里定义的是输出列字段类型
    fieldsOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    //将表结构两部分聚合在一起
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,fieldsOIs)
  }

  // 用于处理数据,入参数组objects里只有1行数据,即每次调用process方法只处理一行数据
  override def process(objects: Array[AnyRef]): Unit = {
    // 将字符串切分为单个字符的数组
    val strings: Array[String] = objects(0).toString.split(" ")
    for (elem <- strings) {
      val tmp = new Array[String](1)
      tmp(0) = elem
      //forward必须传入字符串数组,即使只有一个元素
      forward(tmp)
    }
  }

  override def close(): Unit = {}
}
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值