Spark SQL自定义函数 UDF UDAF

直接上代码

UDF 一进一出

package sparksql_udf

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, SparkSession}

object sparkSqlUDF {
  def main(args: Array[String]): Unit = {
    //创建sparkSession
    //通过sparkSession创建SparkContext
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("zhiDingSchema").getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN")

    //读取数据并加工
    //读取时可以spark.read.textFile("D:\\data\\udf.txt")转化为ds
    val udfRDD: RDD[String] = sc.textFile("D:\\大数据\\学期文档\\spark\\资料\\udf.txt")

    //转化为DS
    import spark.implicits._
    val udfDS: Dataset[String] = udfRDD.toDS()

    //创建临时表
    udfDS.createOrReplaceTempView("udf")

    //注册udf函数
    spark.udf.register("toUpper",(str:String)=>{
      //根据业务需求对数据进行加工
      str.toUpperCase+" 123"
    })

    //sql查询 调用udf函数
    spark.sql("select value,toUpper(value) from udf").show()

    //停止sc、spark
    sc.stop()
    spark.stop()
  }

}

UDAF 多进一出

package sparksql_udf

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object sparkSqlUDAF {

  def main(args: Array[String]): Unit = {
    //创建sparkSession
    //通过sparkSession创建SparkContext
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("zhiDingSchema").getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN")

    //读取数据并加工
    //读取时可以spark.read.textFile("D:\\data\\udf.txt")转化为ds
    val dfJson: DataFrame = spark.read.json("D:\\大数据\\学期文档\\spark\\资料\\udaf.json")

    //注册临时表
    dfJson.createOrReplaceTempView("UDAF")

    //注册UDAF函数
    spark.udf.register("SalaryAvg",new SalaryAvg)

    //计算平均工资的UDAF方法为SalaryAvg
    //查询sql  调用UDAF
    spark.sql("select SalaryAvg(salary)  from UDAF").show()
    //spark.sql("select avg(salary)  from UDAF").show()

    //关闭sc、spark
    sc.stop()
    spark.stop()


  }
  //编写计算平均工资的方法SalaryAvg
  class SalaryAvg extends UserDefinedAggregateFunction {
    //输入的数据类型
    override def inputSchema: StructType = {
      StructType(List(StructField("input", LongType)))
    }

    //缓冲区数据类型 两个
    override def bufferSchema: StructType = {
      StructType(List(StructField("sum", LongType),StructField("total",LongType)))
    }

    //数据返回的类型
    override def dataType: DataType = {
      DoubleType
    }

    //确定是否有相同输出
    override def deterministic: Boolean = {
      true
    }

    /**
     * list(1,2,3,4).reduce(_+_)
     * a=1 b=2
     * a=3 b=3
     */
    //初始化内部数据结构
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      //用于存储总金额
      buffer(0) = 0L
      //用于存储次数
      buffer(1) = 0L
    }

    //更新数据内部结构,区内计算
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      //计算分区总金额
      buffer(0)=buffer.getLong(0)+input.getLong(0)
      //计算分区总次数
      buffer(1)=buffer.getLong(1)+1
    }

    //来自不同分区的数据进行合并,全局合并
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      //汇聚所有分区的总金额
      buffer1(0)=buffer1.getLong(0)+buffer2.getLong(0)
      //汇聚所有分区的总次数
      buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
    }

    //计算输出数据值
    override def evaluate(buffer: Row): Any = {
      buffer.getLong(0).toDouble/buffer.getLong(1).toDouble
    }
  }
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值