【极简spark教程】spark聚合函数

聚合函数分为两类,一种是spark内置的常用聚合函数,一种是用户自定义聚合函数

UDAF

不带类型的UDAF【较常用】

  1. 继承UserDefinedAggregateFunction
  2. 定义输入数据的schema
  3. 定义缓存的数据结构
  4. 聚合函数返回值的数据类型
  5. 定义聚合函数的幂等性,一般为true
  6. 初始化缓存
  7. 更新缓存
  8. 合并缓存
  9. 计算结果

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

object avg extends UserDefinedAggregateFunction {
  // 定义输入数据的schema,需要指定列名,但在实际使用中这里指定的列名没有意义
  override def inputSchema: StructType = StructType(List(StructField("input", LongType)))
  // 缓存的数据结构,bufferSchema定义了缓存的数据结构具有sum和count两个字段
  override def bufferSchema: StructType = StructType(List(StructField("sum", LongType), StructField("count", LongType)))
  // 聚合函数返回值的数据类型:返回值的类型必需与下面的evaluate返回类型一致
  override def dataType: DataType = LongType
  // 聚合函数的幂等性,相同输入总是能得到相同输出
  override def deterministic: Boolean = true
  // 初始化缓存:根据bufferSchema,缓存具有sum和count两个字段,这里会对sum和count两个变量的值进行初始化
  // tips:缓存buffer是MutableAggregationBuffer类型,你可以简单理解buffer就是一个数组
  // tips:在这里buffer是具有代表了sum和count数值的二元数组
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  // 更新缓存:接受并处理输入数据,更新buffer
  // tips:在实际处理中,输入数据是DataFrame,DataFrame是由多个Row组成的,每个Row会逐个传递给update,更新buffer中的值
  // tips:必须对输入的input进行检查,防止input.getLong(i)出现越界报错ArrayIndexOutOfBoundsException
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if(input.isNullAt(0)) return
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }
  // 合并缓存:对多个buffer进行合并,这里的合并方式类似于reduce,新来的buffer都会和左侧合并后的大buffer进行合并,合并后保留大buffer的值,buffer2会被丢弃
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // 计算结果:根据所有buffer合并后的值,计算最终的结果
  // tips:这里所有buffer合并后对值为整体的sum和count,计算整体的sum和count比值,我们得到最终的平均值
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0) / buffer.getLong(1)
  }


}

不带类型的UDAF的使用

  1. 在sparkSQL中使用UDAF

  2. 在DataFrame中使用UDAF

def main(args: Array[String]): Unit = {
  val spark = SparkSession.builder().master("local").getOrCreate()
  // 注册UDAF函数,和UDF函数一样
  spark.udf.register("my_avg", avg)

  // test.txt文件内容
  // score|user
  // 90|Tom
  // 95|Jerry
  // 100|Claris
  // sparkSQL读取文件,创建视图
  
  // sparkSQL的第一步:读取文件并创建视图
  spark.read.option("header","true").option("sep","|").csv("test.txt").createOrReplaceTempView("v_user")
  // sparkSQL的第二步:在spark.sql中调用UDAF,求分数的均值
  spark.sql("select u_avg(score) as avg_score from v_user").show()
  
  // DataFrame的第一步:读取文件,创建DataFrame
  val df1 = spark.read.option("header","true").option("sep","|").csv("data/other/test.txt")
  // DataFrame的第二步:在df.agg中,使用callUDF调用UDAF函数,求分数的均值
  val df2 = df1.agg(callUDF("my_avg",col("score")))
  df2.show(false)

  }

带类型的UDAF【不常用】

  1. 继承Aggregator,继承时须在方括号内指定输入类型、缓存类型、输出类型
  2. 定义作为输入类型的User,作为缓存类型的Average,返回类型为Double
  3. 初始化缓存
  4. 更新缓存
  5. 合并缓存
  6. 计算结果
  7. 固定操作:定义缓存编码器(一般都是Encoders.product)、输出编码器

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.{Aggregator, Window}

case class Average(var sum: Long, var count: Long)
case class User(score: String, name:String)

// 继承Aggregator需要指定输入类型User、缓存类型Average、输出类型Double
object avg1 extends Aggregator[User,Average,Double]{

  // 初始化缓存:这里的缓存为一个Average实例,第一个0L代表sum,第二个0L代表count
  override def zero: Average = Average(0L, 0L)

  // 更新缓存:接受一个User类型,解析出需要的字段,进行累积计算
  override def reduce(b: Average, a: User): Average = {
    b.sum += a.score.toLong
    b.count += 1L
    b
  }

  // 合并缓存:对多个缓存(Average对象)进行合并,所有右侧的Average会逐个合并到最左侧的Average,返回左侧的Average
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // 计算结果:根据合并后的结果计算最终结果
  override def finish(reduction: Average): Double = {
    reduction.sum.toDouble / reduction.count.toDouble
  }

  // 缓存编码器:注意左侧返回类型为Encoder[Average],只要是自定义类型,右侧一般都是Encoders.product
  override def bufferEncoder: Encoder[Average] = Encoders.product
  // 输出编码器:对输出进行编码,编码为java兼容的Double类型
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

带类型的UDAF的使用

  1. 在dataSet中结合select使用UDAF
def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local").getOrCreate()

    // test.txt文件内容
    // score|user
    // 90|Tom
    // 95|Jerry
    // 100|Claris

    // DataSet的第一步:导入隐式转换,否则读取文件并调用as[U]时会报错
    import spark.implicits._
    // DataSet的第二步:读取文件,创建DataSet,这里由于读取的是csv文件,score字段默认为字符串类型,与User样例类中的类型保持一致,否则会报错String cannot cast to int
    val df1 = spark.read.option("header","true").option("sep","|").csv("data/other/test.txt").as[User]
    df1.show(false)
    // DataFrame的第二步:在df.select中调用UDAF,求分数的均值
    val df2 = df1.select(avg1.toColumn.name("test"))
    df2.show(false)
  }
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

鱼摆摆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值