SparkSQL UDFs and UDAFs

1. UDF

Note: the SparkSQL initialization used below is not the latest style; see the previous post for the updated setup.
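For reference, a minimal sketch of the Spark 2.x entry point that replaces the SparkConf/SparkContext/HiveContext setup below (the app name and master here are placeholders):

import org.apache.spark.sql.SparkSession

// Spark 2.x entry point; enableHiveSupport() is only needed when querying Hive tables
val spark = SparkSession.builder()
  .appName("udf")
  .master("local")
  .enableHiveSupport()
  .getOrCreate()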

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

object UDFTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("udf").setMaster("local")
    val sc = new SparkContext(conf)

    val hiveSQLContext = new HiveContext(sc)

    // UDF that upper-cases a name; the explicit parameter type is required
    // for register to compile, and null input falls back to a blank string
    hiveSQLContext.udf.register("toUpper", (name: String) => {
      if (name != null) {
        name.toUpperCase
      } else {
        " "
      }
    })

    // UDF that returns a name's length, treating null as length 0
    hiveSQLContext.udf.register("strLength", (name: String) => {
      if (name != null) {
        name.length
      } else {
        0
      }
    })

    // show() forces the queries to actually run
    hiveSQLContext.sql("select toUpper(name) from student").show()
    hiveSQLContext.sql("select strLength(name) from student").show()
  }
}
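The registered UDFs can also be invoked through the DataFrame API instead of a SQL string; a minimal sketch, assuming the student Hive table from above exists:

import org.apache.spark.sql.functions.callUDF

// Call the registered UDFs on the student table via the DataFrame API
val students = hiveSQLContext.table("student")
students.select(callUDF("toUpper", students("name"))).show()
students.select(callUDF("strLength", students("name"))).show()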
2. UDAF

(1) Version 1.6.0

package lesson02

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types._
/**
  * Requirement:
  *     compute the average salary of all of the company's employees
  * Approach:
  *    1) sum up all employees' salaries: countSalary
  *    2) count the total number of employees: count
  *    3) average salary = countSalary / count
  */
object UDAFTest extends UserDefinedAggregateFunction {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("udf").setMaster("local")
    val sc = new SparkContext(conf)
    val hiveSQLContext = new HiveContext(sc)
    hiveSQLContext.udf.register("avg_salary", UDAFTest)
    hiveSQLContext.sql("select avg_salary(salary) from worker").show()
  }

  // Data type of the input column
  override def inputSchema: StructType = StructType(
    StructField("salary", DoubleType, true) :: Nil
  )

  // Data type of the returned value
  override def dataType: DataType = DoubleType

  /*
   * An aggregate function usually needs some intermediate variables,
   * which are declared here as the aggregation buffer.
   * Our analysis calls for two of them:
   * countSalary: the running total of all salaries
   * count: the running count of employees
   */
  override def bufferSchema: StructType = StructType(
    StructField("countSalary", DoubleType, true) ::
      StructField("count", IntegerType, true) :: Nil
  )

  // Assign initial values to the buffer fields
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0
  }

  /*
   * Update the intermediate results with one input row.
   * @param buffer the values accumulated so far
   * @param input  the current input row
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val countSalary = buffer.getDouble(0)
    val count = buffer.getInt(1)
    val salary = input.getDouble(0)
    buffer(0) = countSalary + salary
    buffer(1) = count + 1
  }

  // Merge two partial buffers, storing the result back into buffer1
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val countSalary1 = buffer1.getDouble(0)
    val count1 = buffer1.getInt(1)
    val countSalary2 = buffer2.getDouble(0)
    val count2 = buffer2.getInt(1)
    buffer1(0) = countSalary1 + countSalary2
    buffer1(1) = count1 + count2
  }

  // Compute the final result from the buffer
  override def evaluate(buffer: Row): Any = {
    val countSalary = buffer.getDouble(0)
    val count = buffer.getInt(1)
    countSalary / count
  }

  // The function always returns the same output for the same input
  override def deterministic: Boolean = true
}
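To try avg_salary without a pre-existing worker table in Hive, a small DataFrame can be registered as a temporary table first. A sketch with hypothetical sample data (the Worker case class is an assumption, defined at the top level so schema reflection works):

// Hypothetical sample data to exercise avg_salary (place the case class at top level)
case class Worker(name: String, salary: Double)

// Inside main, after registering avg_salary:
import hiveSQLContext.implicits._
val workers = sc.parallelize(Seq(Worker("a", 1000.0), Worker("b", 2000.0))).toDF()
workers.registerTempTable("worker")
hiveSQLContext.sql("select avg_salary(salary) from worker").show() // expect 1500.0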



(2) Version 2.2.0

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

object MyAverage extends UserDefinedAggregateFunction {
  // Data types of input arguments of this aggregate function
  def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
  // Data types of values in the aggregation buffer
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }
  // The data type of the returned value
  def dataType: DataType = DoubleType
  // Whether this function always returns the same output on the identical input
  def deterministic: Boolean = true
  // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
  // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
  // the opportunity to update its values. Note that arrays and maps inside the buffer are still
  // immutable.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  // Updates the given aggregation buffer `buffer` with new input data from `input`
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }
  // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Calculates the final result
  def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}

// Register the function to access it
spark.udf.register("myAverage", MyAverage)

val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
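For completeness: in Spark 3.0 and later, UserDefinedAggregateFunction is deprecated in favor of a typed Aggregator registered through functions.udaf. A minimal sketch of the same sum/count average in that style (assuming a Spark 3.x spark session and the employees table above):

import org.apache.spark.sql.{Encoder, Encoders, functions}
import org.apache.spark.sql.expressions.Aggregator

// Same average as MyAverage, expressed as a typed Aggregator (Spark 3.x style)
object MyAverage3 extends Aggregator[Long, (Long, Long), Double] {
  def zero: (Long, Long) = (0L, 0L)
  def reduce(b: (Long, Long), a: Long): (Long, Long) = (b._1 + a, b._2 + 1)
  def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = (b1._1 + b2._1, b1._2 + b2._2)
  def finish(r: (Long, Long)): Double = r._1.toDouble / r._2
  def bufferEncoder: Encoder[(Long, Long)] = Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

spark.udf.register("myAverage3", functions.udaf(MyAverage3))
spark.sql("SELECT myAverage3(salary) AS average_salary FROM employees").show()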

