SparkSQL的内置函数与自定义函数

SparkSQL内置函数
sparkSQL支持大量的常见函数,具体参考此文章https://www.iteblog.com/archives/2336.html

SparkSQL的自定义函数(UDF)
在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:

•UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
•UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg等
•UDTF(User-Defined Table-Generating Functions),用户自定义生成函数,有点像stream里面的flatMap

自定义一个UDF函数需要继承UserDefinedAggregateFunction类,并实现其中的8个方法,然后注册临时函数:(不推荐)

package com.zhbr.process

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StructType}

object UdfTest extends UserDefinedAggregateFunction{
  override def inputSchema: StructType = ???

  override def bufferSchema: StructType = ???

  override def dataType: DataType = ???

  override def deterministic: Boolean = ???

  override def initialize(buffer: MutableAggregationBuffer): Unit = ???

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???

  override def evaluate(buffer: Row): Any = ???
}
spark.udf.register("toDate",UdfTest)
spark.sql("select toDate(pp) from frame").show()

每个方法代表的含义是:

inputSchema:输入数据的类型
bufferSchema:产生中间结果的数据类型
dataType:最终返回的结果类型
deterministic:确保一致性(输入什么类型的数据就返回什么类型的数据),一般用true
initialize:指定初始值
update:每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
merge:全局聚合(将每个分区的结果进行聚合)
evaluate:计算最终的结果

推荐方式:因为更简单

spark.udf.register("toDate", (date: Timestamp) => {
      val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
      val strTime = sdf.format(date)
      val splits = strTime.split(" ")
      val yearMothDay = splits(0)
      val sdf2 = new SimpleDateFormat("yyyy-MM-dd")
      val date1: Date = sdf2.parse(yearMothDay)
      new java.sql.Date(date1.getTime)
    })
 spark.sql("select toDate(pp) from frame").show()

拓展SparkSQL的自定义函数(UDAF)

package com.zhbr.process
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
class CustomerAvg extends UserDefinedAggregateFunction {
  //输入的类型
  override def inputSchema: StructType = StructType(StructField("salary", LongType) :: Nil)
  //缓存数据的类型
  override def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }
  //返回值类型
  override def dataType: DataType = LongType
  //幂等性
  override def deterministic: Boolean = true
  //初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
//更新 分区内操作
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0)=buffer.getLong(0) +input.getLong(0)
    buffer(1)=buffer.getLong(1)+1L
  }
//合并 分区与分区之间操作
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0)=buffer1.getLong(0)+buffer2.getLong(0)
    buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
  }
  //最终执行的方法
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0)/buffer.getLong(1)
  }
}
spark.udf.register("MyAvg",CustomerAvg)
spark.sql("select MyAvg(age) avg_age from frame").show()

拓展SparkSQL自定义函数(UDTF)

class UserDefinedUDTF extends GenericUDTF{

  //这个方法的作用:1.输入参数校验  2. 输出列定义,可以多于1列,相当于可以生成多行多列数据
  override def initialize(args:Array[ObjectInspector]): StructObjectInspector = {
    if (args.length != 1) {
      throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
    }
    if (args(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
    }

    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]

    //这里定义的是输出列默认字段名称
    fieldNames.add("col1")
    //这里定义的是输出列字段类型
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)

    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  //这是处理数据的方法,入参数组里只有1行数据,即每次调用process方法只处理一行数据
  override def process(args: Array[AnyRef]): Unit = {
    //将字符串切分成单个字符的数组
    val strLst = args(0).toString.split("")
    for(i <- strLst){
      var tmp:Array[String] = new Array[String](1)
      tmp(0) = i
      //调用forward方法,必须传字符串数组,即使只有一个元素
      forward(tmp)
    }
  }

  override def close(): Unit = {}
}

关于SparkSQL自定义函数的小结:参考https://blog.csdn.net/laksdbaksjfgba/article/details/87162906#3_UDTF_218

关于UDF
简单粗暴的理解,它就是输入一行输出一行的自定义算子
我们可以通过实名函数或匿名函数的方式来实现,并使用sparkSession.udf.register()注册
需要注意,截至目前(spark2.4)最多只支持22个输入参数的UDF

关于UDAF
简单粗暴的理解,它就是输入多行输出一行的自定义算子,比UDF的功能强大一些
通过实现抽象类org.apache.spark.sql.expressions.UserDefinedAggregateFunction来实现UDAF算子,并使用sparkSession.udf.register()注册

关于UDTF
简单粗暴的理解,它就是输入一行输出多行的自定义算子,可输出多行多列,又被称为 “表生成函数”
通过实现抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDTF来实现UDTF算子,但是似乎无法使用sparkSession.udf.register()注册。注册方法如下:
sparkSession.sql(“CREATE TEMPORARY FUNCTION 自定义算子名称 as ‘算子实现类全限定名称’”)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值