Spark_udf_udaf

弱类型
package com.atguigu.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
//求年龄平均值
//1.需要两个值 累加的年龄和 出现的次数

//user defind aggregate function
object Spark_UDAF {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName(“wc”).setMaster(“local[*]”)
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
//创建RDD
//引入隐式转换
import spark.implicits._
//创建聚合函数
val udaf = new MyAgeFunction
//注册聚合函数
spark.udf.register(“avgAge”,udaf)
val frame: DataFrame = spark.read.json(“in/person.json”)
//将df转换为表
frame.createOrReplaceTempView(“stu”)
spark.sql(“select avgAge(age) from stu”).show()
}
}
//继承UserDefinedAggregateFunction
//实现方法
class MyAgeFunction extends UserDefinedAggregateFunction{
//输入的结构是什么样的 传入的是年龄 直接new 增加年龄
override def inputSchema: StructType = {
//输入的字段叫age,longtype类型
new StructType().add(“age”,LongType)
}
//Buffer 计算时的数据结构 缓冲区的数据结构 输入数据的结构
override def bufferSchema: StructType = {
new StructType().add(“sum”,LongType).add(“count”,LongType)
}
//out 函数返回的数据类型 sum/count
override def dataType: DataType = {
DoubleType
}
//函数是否稳定 给什么值就返回什么
override def deterministic: Boolean = true
//函数计算之前缓冲区的初始化 既sum 和count
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//buffer(0)=0L
//buffer(1)=0L //既刚开始把变量count和sum放进缓冲区时的值
buffer.update(0,0L)
buffer.update(1,0L)
}
//更新数据 计算的每一条数据更新缓冲区 节点内更新
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//将缓冲区的数据与传入的数据相加
//count每次加1
buffer(0)=buffer.getLong(0)+input.getLong(0)
buffer(1)=buffer.getLong(1)+1
}
//将多个节点的缓冲区合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//sum
buffer1(0)=buffer1.getLong(0)+buffer2.getLong(0)
//count
buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
}
//计算 sum/count
override def evaluate(buffer: Row): Any = {
buffer.getLong(0).toDouble/buffer.getLong(1)
}
}

package com.atguigu.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
//求年龄平均值
//1.需要两个值 累加的年龄和 出现的次数

//user defind aggregate function
object Spark_UDAF2 {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName(“wc”).setMaster(“local[*]”)
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
//创建RDD
//引入隐式转换
import spark.implicits._
//创建聚合函数
val udaf = new MyAgeFunctionOb
//将聚合函数转换为查询列
val avgCol: TypedColumn[UserBean, Double] = udaf.toColumn.name(“avgAge”)

val frame: DataFrame = spark.read.json("in/person.json")
val userDS: Dataset[UserBean] = frame.as[UserBean]
userDS.select(avgCol).show()
spark.stop()

}
}
case class UserBean(name:String,age:BigInt)
//处理逻辑类
case class AvgBuffer(var sum:BigInt,var count:Int)
//继承Aggregator
//强类型
// 自定义聚合函数类 继承Aggregator[-IN, BUF, OUT]
class MyAgeFunctionOb extends Aggregator[UserBean,AvgBuffer,Double]{
//初始化 缓冲区(age0,conut0)
override def zero: AvgBuffer = {
AvgBuffer(0,0)
}
//根据输入数据更新缓冲区 返回缓冲区 把输入的对象和冲缓区做操作
override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
b.sum=b.sum+a.age
b.count=b.count+1
b
}
//缓冲区的合并操作
override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
b1.sum = b1.sum+b2.sum
b1.count=b1.count+b2.count
b1
}
//完成计算
override def finish(reduction: AvgBuffer): Double = {
reduction.sum.toDouble/reduction.count
}
//自定义类型的转码 编码解码问题 固定写法 不用管
override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值