Spark sql自定义函数UTF/UTAF

自定义UTF函数 弱类型

package SparkSql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}


object Session_UDAF_1 {

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.ERROR)
    val sc: SparkConf = new SparkConf().setMaster("local[*]").setAppName("DataFrame")
    val spark: SparkSession = SparkSession.builder().config(sc).getOrCreate()
    //TODO 导入隐式转换 这里的spark不是包的名字,而是SparkSession对象的名字
    // (这里并没有用到RDD,DataFrame,DataSet类型转换,但一般建议出现SparkSession对象时,都要加上
    import spark.implicits._
    val frame: DataFrame = spark.read.
      format("csv").
      option("header", "true").option("inferSchema", true.toString).
      load("F:\\MySQL\\python招聘.csv")
    //DataFrame(表结构)对象创建一个临时视图
    frame.createTempView("python")
    //查看frame数据集字段名RecruitPostName,RecruitPostId前5行
    spark.sql("select  RecruitPostName,RecruitPostId from python limit 5").show()

    //创建 UTF函数对象
    val function = new MyAgeFunction()
    //绑定临时函数
    spark.udf.register("avg_RecruitPostId", function)

    //使用自定义UTF函数求字段RecruitPostId的平均值
    spark.sql("select avg_RecruitPostId(RecruitPostId) from python").show()

    spark.stop()
  }

}
//声明用户自定义聚合函数(弱类型)
// 1) 继承UserDefinedAggregateFunction
// 2)实现方法
class MyAgeFunction extends UserDefinedAggregateFunction {

  //函数的输入结构
  override def inputSchema: StructType = {

    new StructType().add("Name", LongType)
  }

  // 计算时函数的的数据结构
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  //函数返回的数据类型
  override def dataType: DataType = DoubleType

  //函数是否稳定
  override def deterministic: Boolean = true

  //计算之间函数 缓冲区的初始化 给sum,count一个初始值
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  //根据查询结果更新缓冲区数据
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

    buffer(0) = buffer.getLong(0) + input.getLong(0) //单节点sum求和
    buffer(1) = buffer.getLong(1) + 1 // 单节点count计数
  }

  //将多个节点 合并缓冲区数据(其实类似分区,分区内计算完毕,分区间聚合分区内的结果)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) //多节点sum求和
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) //多节点count计数
  }
  // 将最终计算结果返回
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}

在这里插入图片描述
注意,如果您的文件内容表头(字段名)有特殊字符(空格之类,我使用的是spark2.2.1版本,有出错请务必去除特殊字符,
不过好像2.3版本之后有改进,不知道是不是真的,反正我这块表头有特殊字符,就报错)

自定义UTAF函数 强类型

package SparkSql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}

object Session_UDAF_class {

  def main(args: Array[String]): Unit = {

    Logger.getLogger("org").setLevel(Level.ERROR)

    val sc: SparkConf = new SparkConf().setMaster("local[*]").setAppName("DataSet")

    val spark: SparkSession = SparkSession.builder().config(sc).getOrCreate() //导入隐式转换 这里的spark不是包的名字,而是SparkSession对象的名字

    val frame: DataFrame = spark.read.
      format("csv").
      option("header", "true").
      option("inferSchema", true.toString).
      load("E:\\谷歌文件\\archive\\restaurant-1-orders.csv")
    //TODO 导入隐式转换 这里的spark不是包的名字,而是SparkSession对象的名字
    import spark.implicits._
    //Order_ID    Order_Date Item_Name Quantity Product_Price Total_products
    frame.createTempView("test")   //创建一个临时视图
    spark.sql("select Order_ID,Product_Price from test limit 5").show()

//TODO DataFrame类型转RDD算子  注意RDD算子类型,DataFrame类型转RDD类 类型都将是ROW(行)
    val rdd: RDD[Row] = frame.rdd

    val resultRDD: RDD[Array[Any]] = rdd.map {
      case x => Array(x.getInt(0), x.getString(1), x.getString(2),
        x.getInt(3), x.getDouble(4), x.getInt(5))
    }

    //创建聚合函数对象
    val clazz = new MyFunctionClass()

    //将聚合函数转换为查询列
    val value: TypedColumn[OrderBean, Double] = clazz.toColumn.name("avg_Product_Price")

    // 将Dataframe类型转换为DataSet类型
    val OrderDS: Dataset[OrderBean] = frame.as[OrderBean]

    OrderDS.select(value).show()

    //释放资源
    spark.stop()
  }

}

case class OrderBean(Order_ID: Long, Order_Date: String, Item_Name: String,
                     Quantity: Long, Product_Price: Double, Total_products: Long)

case class AvgBean(var sum: Double, var count: Int)

//声明用户自定义聚合函数(强类型)
// 1) 继承
// 2)实现方法
class MyFunctionClass extends Aggregator[OrderBean, AvgBean, Double] {

  //初始化
  override def zero: AvgBean = {
    AvgBean(0.0, 0)
  }
  //聚合函数,与UserDefinedAggregateFunction的函数update原理相同,聚合函数更好理解而已
  override def reduce(b: AvgBean, a: OrderBean): AvgBean = {
    b.sum = b.sum + a.Product_Price    //计算Product_Price这一列的总和(单节点)
    b.count = b.count + 1
    b
  }
  //与UserDefinedAggregateFunction的merge实现也是 合并多个节点数据
  override def merge(b1: AvgBean, b2: AvgBean): AvgBean = {

    b2.sum = b1.sum + b2.sum
    b2.count = b1.count + b2.count

    b2
  }

  //完成计算
  override def finish(reduction: AvgBean): Double = reduction.sum / reduction.count

  //以下是固定写法
  override def bufferEncoder: Encoder[AvgBean] = Encoders.product //用户自定义类型写法

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble //scala中自带的类型写法
}

在这里插入图片描述

总结: RDD算子难用,将其转换为DataFrame类型(二维表结构),Spark提供sql语句来很轻松的使用DataFrame,相对来说简单许多,至于DataSet是将DataFrame类型中的字段名封装成类(我们一般是使用样例类),类就有属性(字段名),毕竟在面向对象程序语言中,我们还是更希望通过对象来访问数据。

  • 4
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值