SparkSQL代码笔记02——UDF、UDAF

一、UDF

 

package com.zgm.sc.day14

import org.apache.spark.sql.SparkSession

/**
  * 用udf实现字符串拼接
  */
object UDFDemo1 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("UDFDemo1")
      .master("local")
      .getOrCreate()

    val df = spark.read.json("dir/people.json")

    // 注册函数,注册后正在整个应用中都可以用
    spark.udf.register("newname", (x: String) => "name:" + x)

    df.createOrReplaceTempView("person")
    spark.sql("select newname(name) as new_name  from person").show()

    spark.stop()
  }
}


//运行结果:

+------------+
|    new_name|
+------------+
|name:Michael|
|   name:Andy|
| name:Justin|
+------------+

 

二、UDAF

用户自定义聚合函数

 

 

1、UDAF函数支持DataFrame(弱类型)

过继承UserDefinedAggregateFunction来实现用户自定义聚合函数。下面展示一个求平均工资的自定义聚合函数。

ps:弱类型指的是在编译阶段是无法确定数据类型的,而是在运行阶段才能创建类型

package com.qf.gp1921.day13

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}

/**
  * 使用UDAF操作DataFrame
  * 需求:用UDAF统计员工的平均薪资
  */
class UDAFDemo1 extends UserDefinedAggregateFunction {
  // 指定输入类型
  override def inputSchema: StructType = StructType(Array(StructField("salary", DoubleType, true)))
  // 缓冲的作用是将上次的结果和这次传进来的结果进行聚合,指定缓冲(分区)类型和聚合过程
  override def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: StructField("count", DoubleType) :: Nil)

  // 返回类型
  override def dataType: DataType = DoubleType

  // 如果给true,有相同的输入,该函数就有相同的输出
  // 如果输入的数据有不同的情况,比如每次数据有不同的时间或有不同的数据对应的offset,
  // 这时候得到的结果可能就不一样,这个值就设置为false
  override def deterministic: Boolean = true

  // 初始化方法,对buffer中的数据进行初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // 员工的总薪资
    buffer(0) = 0.0
    // 员工的人数
    buffer(1) = 0.0
  }

  // 局部聚合,分区内的聚合
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // buffer是指之前的结果,input是这次传进来的数据,将要和buffer进行聚合
    if (!input.isNullAt(0)) {
      // 聚合薪资
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      // 聚合人数
      buffer(1) = buffer.getDouble(1) + 1
    }
  }

  // 全局聚合,分区和分区的聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // 合并薪资
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    // 合并人数
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  // 最终的结果,可以在该方法中对结果进行再次处理
  override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getDouble(1)
}

object UDAFDemo1 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("UDFDemo1")
      .master("local")
      .getOrCreate()
    spark.udf.register("aggrfunc", new UDAFDemo1)

    val df = spark.read.json("dir/employees.json")
    df.createOrReplaceTempView("employees")

    df.show()

    val res = spark.sql("select aggrfunc(salary) as avgsalary from employees")
    res.show()

    spark.stop()
  }
}


//运行结果
+------+
|avgsalary|
+------+
|3750.0|
+------+

2.UDAF函数支持DataSet(强类型)

通过继承Aggregator来实现强类型自定义聚合函数,同样是求平均工资

ps:在编译阶段就确定了数据类型

package com.qf.gp1921.day13

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession, TypedColumn}

/**
  * udaf操作DataSet
  */
// 生成DataSet的类型
case class Employee(name: String, salary: Double)
// 作为缓冲的类型
case class AvgSalary(var sum: Double, var count: Double)

class UDAFDemo2 extends Aggregator[Employee, AvgSalary, Double]{
  // 初始化方法,初始化每个buffer
  override def zero: AvgSalary = AvgSalary(0.0, 0.0)
  // 局部聚合
  override def reduce(buffer: AvgSalary, employee: Employee): AvgSalary = {
    buffer.sum += employee.salary // 聚合薪资
    buffer.count += 1 // 聚合人数
    buffer
  }
  // 全局聚合
  override def merge(b1: AvgSalary, b2: AvgSalary): AvgSalary = {
    b1.sum += b2.sum // 分区和分区的聚合,聚合薪资
    b1.count += b2.count // 聚合人数
    b1
  }
  // 计算结果
  override def finish(reduction: AvgSalary): Double = reduction.sum / reduction.count

  // 设置中间值的编码, 用的编码和Tuple和case是一样的
  override def bufferEncoder: Encoder[AvgSalary] = Encoders.product
  // 设置最终结果的编码
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object UDAFDemo2 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("UDFDemo2")
      .master("local")
      .getOrCreate()

    import spark.implicits._

    val ds: Dataset[Employee] = spark.read.json("dir/employees.json").as[Employee]

    ds.show

    // 指定某个列并调用udaf
    val avgsalary: TypedColumn[Employee, Double] = new UDAFDemo2().toColumn.name("avg_salary")
    val res: Dataset[Double] = ds.select(avgsalary)
    res.show

    spark.stop()
  }
}



//运行结果
+------+
|avg_salary|
+------+
|3750.0|
+------+




 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值