object learn04 {
  /**
   * Entry point: builds a local SparkSession, turns a small Int RDD into a
   * typed Dataset[UserBean], and runs the custom strongly-typed average
   * aggregator over the `age` column.
   */
  def main(args: Array[String]): Unit = {
    // Basic configuration.
    // Fix: appName was "learn01", inconsistent with this object's name.
    val conf = new SparkConf().setAppName("learn04").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._
    // Create RDD -> Dataset. A plain lambda suffices; the original wrapped a
    // total transformation in a one-case partial function for no benefit.
    // (Int is lifted to BigInt by the implicit widening conversion.)
    val dataRdd = spark.sparkContext.makeRDD(List(1, 2, 3, 4, 5))
    val dataDs = dataRdd.map(age => UserBean(age)).toDS()
    // Register the aggregator as a typed column and show the result.
    val avgFun = new MyAgeAvgClassFunction
    val avgColumn = avgFun.toColumn.name("avgFun")
    dataDs.select(avgColumn).show()
    // Fix: release the session instead of leaking it at program end.
    spark.stop()
  }
}
// Input row type for the aggregator: a single user's age.
case class UserBean(age: BigInt)
// Intermediate aggregation buffer: running sum of ages and count of rows seen.
case class AvgBuffer(sum: BigInt, count: Int)
/**
 * Strongly-typed average aggregator (avg over `UserBean.age`).
 *
 * Usage pattern:
 *   1) extend `Aggregator[IN, BUF, OUT]` — here [UserBean, AvgBuffer, Double];
 *   2) implement zero / reduce / merge / finish plus the two encoders.
 */
class MyAgeAvgClassFunction extends Aggregator[UserBean, AvgBuffer, Double] {
  /** Initial (empty) buffer: zero sum, zero count. */
  override def zero: AvgBuffer = AvgBuffer(0, 0)

  /** Fold one input row into a buffer within a partition. */
  override def reduce(buffer: AvgBuffer, user: UserBean): AvgBuffer =
    AvgBuffer(buffer.sum + user.age, buffer.count + 1)

  /** Combine two partial buffers when merging partitions. */
  override def merge(left: AvgBuffer, right: AvgBuffer): AvgBuffer =
    AvgBuffer(left.sum + right.sum, left.count + right.count)

  /** Produce the final average from the fully merged buffer. */
  override def finish(reduction: AvgBuffer): Double =
    reduction.sum.toDouble / reduction.count

  // Buffer is a user-defined case class -> use the product encoder;
  // primitive output types use the encoders Spark provides.
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
// Spark UDF: user-defined aggregate function (strongly typed)
// Latest recommended article published 2024-04-01 09:25:04