Spark SQL 3.0 UDAF: Sum and Average

This post shows how to write a custom UDAF with Spark SQL 3.0.

Before 3.0, the way to do this was to extend UserDefinedAggregateFunction, which is deprecated as of 3.0.
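For contrast, here is a minimal sketch of that deprecated style (the object name LegacySum and its schemas are illustrative, not from the original post):

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}

// deprecated pre-3.0 style: untyped, Row-based buffers
object LegacySum extends UserDefinedAggregateFunction {
  // schema of the input column(s)
  override def inputSchema: StructType = StructType(StructField("value", IntegerType) :: Nil)
  // schema of the aggregation buffer
  override def bufferSchema: StructType = StructType(StructField("sum", IntegerType) :: Nil)
  // type of the final result
  override def dataType: DataType = IntegerType
  override def deterministic: Boolean = true
  // initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0
  // fold one input row into the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer(0) = buffer.getInt(0) + input.getInt(0)
  // combine two buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
  // produce the final result
  override def evaluate(buffer: Row): Any = buffer.getInt(0)
}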

The new approach in 3.0 is to extend Aggregator and wrap it with functions.udaf, as in the code below.

The code defines two custom aggregations: sum and average.

package com.cy.spark

import org.apache.log4j.{Level, Logger}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession}


object DemoUDAF {
  def main(args: Array[String]): Unit = {

    // silence noisy Spark logging
    Logger.getLogger("org").setLevel(Level.ERROR)

    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName)
      .master("local[4]")
      .getOrCreate()
    val sc = spark.sparkContext

    val rdd: RDD[String] = sc.textFile("E://file/spark/student01.txt")

    val stuRdd: RDD[Stu1] = rdd.map(line => {
      // each line looks like: class01 tom 100
      val split = line.split(" ")
      val classess = split(0)
      val name = split(1)
      val score = split(2).toInt
      Stu1(classess, name, score)
    })
    // needed for the toDF conversion
    import spark.implicits._
    // RDD -> DataFrame
    val df: DataFrame = stuRdd.toDF

    df.createOrReplaceTempView("stu")

    import org.apache.spark.sql.functions._

    // UDAF: average
    val avgAgg1 = new Aggregator[Double, (Double, Int), Double] {
      // initial value of the buffer: (running sum, count)
      override def zero: (Double, Int) = (0.0, 0)
      // fold one input value into the buffer (per-partition aggregation)
      override def reduce(b: (Double, Int), a: Double): (Double, Int) = {
        (b._1 + a, b._2 + 1)
      }
      // combine two buffers (cross-partition aggregation)
      override def merge(b1: (Double, Int), b2: (Double, Int)): (Double, Int) = {
        (b1._1 + b2._1, b1._2 + b2._2)
      }
      // compute the final result from the buffer
      override def finish(reduction: (Double, Int)): Double = {
        reduction._1 / reduction._2
      }
      // encoder for the intermediate buffer
      override def bufferEncoder: Encoder[(Double, Int)] = {
        Encoders.tuple(Encoders.scalaDouble, Encoders.scalaInt)
      }
      // encoder for the output value
      override def outputEncoder: Encoder[Double] = {
        Encoders.scalaDouble
      }
    }

    // UDAF: sum
    val sumAgg = new Aggregator[Int, Int, Int] {
      // initial value of the buffer
      override def zero: Int = 0
      // fold one input value into the buffer (per-partition aggregation)
      override def reduce(b: Int, a: Int): Int = b + a
      // combine two buffers (cross-partition aggregation)
      override def merge(b1: Int, b2: Int): Int = b1 + b2
      // compute the final result from the buffer
      override def finish(reduction: Int): Int = reduction
      // encoder for the intermediate buffer
      override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
      // encoder for the output value
      override def outputEncoder: Encoder[Int] = Encoders.scalaInt
    }

    // register the custom UDAFs for use in SQL
    spark.udf.register("sum1", udaf(sumAgg))
    spark.udf.register("avg1", udaf(avgAgg1))
    val sql =
      """
        |select classess, sum1(score) as total_score, avg1(score) as avg_score
        |from stu
        |group by classess
        |""".stripMargin

    spark.sql(sql).show()

    spark.stop()
  }
}

case class Stu1(classess:String, name:String, score:Int)
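
Besides SQL registration, a 3.0 Aggregator can also be applied to a typed Dataset through toColumn. A minimal sketch, reusing spark, df, and Stu1 from the listing above (the name avgScore is illustrative):

// typed average over whole rows: Aggregator[Stu1, (sum, count), Double]
val avgScore = new Aggregator[Stu1, (Double, Int), Double] {
  override def zero: (Double, Int) = (0.0, 0)
  override def reduce(b: (Double, Int), s: Stu1): (Double, Int) = (b._1 + s.score, b._2 + 1)
  override def merge(b1: (Double, Int), b2: (Double, Int)): (Double, Int) =
    (b1._1 + b2._1, b1._2 + b2._2)
  override def finish(r: (Double, Int)): Double = r._1 / r._2
  override def bufferEncoder: Encoder[(Double, Int)] =
    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaInt)
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

import spark.implicits._
// group by class and apply the typed aggregator as a column
df.as[Stu1]
  .groupByKey(_.classess)
  .agg(avgScore.toColumn.name("avg_score"))
  .show()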

Data source (add more rows of your own in the same format):

class01 tom 100
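
With just this single row, the SQL query above would print something like:

+--------+-----------+---------+
|classess|total_score|avg_score|
+--------+-----------+---------+
| class01|        100|    100.0|
+--------+-----------+---------+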
