SparkSql自定义udaf函数

一、弱类型(低版本3.0以下)

        定义类继承 UserDefinedAggregateFunction,并重写其中方法

class MyAveragUDAF extends UserDefinedAggregateFunction {

// 聚合函数输入参数的数据类型
def inputSchema: StructType = StructType(Array(StructField("age",IntegerType)))

// 聚合函数缓冲区中值的数据类型(age,count) def bufferSchema: StructType = {

StructType(Array(StructField("sum",LongType),StructField("count",LongType)))
}
// 函数返回值的数据类型
def dataType: DataType = DoubleType

// 稳定性:对于相同的输入是否一直返回相同的输出。
def deterministic: Boolean = true

// 函数缓冲区初始化
def initialize(buffer: MutableAggregationBuffer): Unit = {
// 存年龄的总和
buffer(0) = 0L
// 存年龄的个数
buffer(1) = 0L
}

// 更新缓冲区中的数据
def update(buffer: MutableAggregationBuffer,input: Row): Unit = { if (!input.isNullAt(0))                     {
buffer(0) = buffer.getLong(0) + input.getInt(0) buffer(1) = buffer.getLong(1) + 1
}
}

// 合并缓冲区
def merge(buffer1: MutableAggregationBuffer,buffer2: Row): Unit = { buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}

// 计算最终结果
def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}

。。。

//创建聚合函数
var myAverage = new MyAveragUDAF

//在 spark 中注册聚合函数
spark.udf.register("avgAge",myAverage)

二、强类型(3.0后推荐使用)

  • 定义类继承 org.apache.spark.sql.expressions.Aggregator
  • 重写类中的方法

        低版本:采用DSL风格查询

//输入数据类型
case class User01(username:String,age:Long)
//缓存类型
case class AgeBuffer(var sum:Long,var count:Long)

/**
*	定义类继承 org.apache.spark.sql.expressions.Aggregator
*	重写类中的方法
*/
class MyAveragUDAF1 extends Aggregator[User01,AgeBuffer,Double]{ override def zero: AgeBuffer = {
AgeBuffer(0L,0L)
}

override def reduce(b: AgeBuffer, a: User01): AgeBuffer = {
b.sum = b.sum + a.age b.count = b.count + 1 b
}

override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = { b1.sum = b1.sum + b2.sum
b1.count = b1.count + b2.count b1
}

override def finish(buff: AgeBuffer): Double = { buff.sum.toDouble/buff.count
}
//DataSet 默认额编解码器,用于序列化,固定写法
//自定义类型就是 product 自带类型根据类型选择
override def bufferEncoder: Encoder[AgeBuffer] = { Encoders.product
}

override def outputEncoder: Encoder[Double] = { Encoders.scalaDouble
}
}

。。。

//封装为 DataSet
val ds: Dataset[User01] = df.as[User01]

//创建聚合函数
var myAgeUdaf1 = new MyAveragUDAF1
//将聚合函数转换为查询的列
val col: TypedColumn[User01, Double] = myAgeUdaf1.toColumn

ds.select(col).show()

        高版本(3.0):注册函数方式

// TODO 创建 UDAF 函数
val udaf = new MyAvgAgeUDAF
// TODO 注册到 SparkSQL 中
spark.udf.register("avgAge", functions.udaf(udaf))
// TODO 在 SQL 中使用聚合函数

// 定义用户的自定义聚合函数
spark.sql("select avgAge(age) from user").show
// ************************************************** case class Buff( var sum:Long, var cnt:Long )
// totalage, count
class MyAvgAgeUDAF extends Aggregator[Long, Buff, Double]{ override def zero: Buff = Buff(0,0)

override def reduce(b: Buff, a: Long): Buff = { b.sum += a
b.cnt += 1 b
}
override def merge(b1: Buff, b2: Buff): Buff = { b1.sum += b2.sum
b1.cnt += b2.cnt b1
}

override def finish(reduction: Buff): Double = { reduction.sum.toDouble/reduction.cnt
}
override def bufferEncoder: Encoder[Buff] = Encoders.product override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
使用SparkSQL和Hive API,可以通过以下步骤实现用户自定义函数(UDF)、聚合函数UDAF)和表生成函数(UDTF): 1. 编自定义函数的代码,例如: ``` // UDF def myUDF(str: String): Int = { str.length } // UDAF class MyUDAF extends UserDefinedAggregateFunction { override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil) override def bufferSchema: StructType = StructType(StructField("count", IntegerType) :: Nil) override def dataType: DataType = IntegerType override def deterministic: Boolean = true override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0 } override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { buffer(0) = buffer.getInt(0) + input.getString(0).length } override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0) } override def evaluate(buffer: Row): Any = { buffer.getInt(0) } } // UDTF class MyUDTF extends GenericUDTF { override def initialize(args: Array[ConstantObjectInspector]): StructObjectInspector = { // 初始化代码 } override def process(args: Array[DeferedObject]): Unit = { // 处理代码 } override def close(): Unit = { // 关闭代码 } } ``` 2. 将自定义函数注册到SparkSQL或Hive中,例如: ``` // SparkSQL中注册UDF spark.udf.register("myUDF", myUDF _) // Hive中注册UDF hiveContext.sql("CREATE TEMPORARY FUNCTION myUDF AS 'com.example.MyUDF'") // Hive中注册UDAF hiveContext.sql("CREATE TEMPORARY FUNCTION myUDAF AS 'com.example.MyUDAF'") // Hive中注册UDTF hiveContext.sql("CREATE TEMPORARY FUNCTION myUDTF AS 'com.example.MyUDTF'") ``` 3. 在SQL语句中使用自定义函数,例如: ``` -- 使用SparkSQL中的UDF SELECT myUDF(name) FROM users -- 使用Hive中的UDF SELECT myUDF(name) FROM users -- 使用Hive中的UDAF SELECT myUDAF(name) FROM users GROUP BY age -- 使用Hive中的UDTF SELECT explode(myUDTF(name)) FROM users ``` 以上就是使用SparkSQL和Hive API实现用户自定义函数(UDF、UDAF、UDTF)的步骤。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值