Writing a Custom UDAF in Spark

package main.scala

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}

/**
 * Spark SQL UDAF: user defined aggregation function
 * UDF: the input is a single data record; in implementation terms it is an ordinary Scala function that simply needs to be registered
 * UDAF: a user-defined aggregate function; it operates on a set of rows and lets you define custom aggregation logic on top of them
 */
object SparkSQLUDF {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setMaster("local[*]").setAppName("SparkSQLUDF")
    val sc = new SparkContext(conf)

    val sqlContext = new SQLContext(sc)

    val bigData = Array("Spark","Hadoop","Flink","Spark","Hadoop","Flink","Spark","Hadoop","Flink","Spark","Hadoop","Flink")
    val bigDataRDD = sc.parallelize(bigData)

    val bigDataRowRDD = bigDataRDD.map(line => Row(line))
    val structType = StructType(Array(StructField("name",StringType,true)))
    val bigDataDF = sqlContext.createDataFrame(bigDataRowRDD, structType)

    bigDataDF.registerTempTable("bigDataTable")

    /*
     * Register the UDF via SQLContext; in Scala 2.10.x a UDF can accept at most 22 input parameters
     */
    sqlContext.udf.register("computeLength",(input:String) => input.length)
    sqlContext.sql("select name,computeLength(name) as length from bigDataTable").show
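
    // For illustration only: a UDF may also take multiple arguments (up to the
    // 22-parameter limit noted above). "concatWith" is a hypothetical example.
    sqlContext.udf.register("concatWith", (s: String, sep: String) => s + sep)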

    // while(true){}  // uncomment to keep the application alive so the Spark web UI can still be inspected

     hiveContext.udf.register("wordCount",new MyUDAF)
     hiveContext.sql("select name,wordCount(name) as count,computeLength(name) as length 
     from bigDataTable group by name ").show
  }
}

/**
 * User-defined aggregate function: counts the occurrences of each value
 */
class MyUDAF extends UserDefinedAggregateFunction {
  /**
   * Specifies the concrete type of the input data
   * The field name is arbitrary: users can choose names to identify the input arguments - "name" here, or any other string
   */
  override def inputSchema:StructType = StructType(Array(StructField("name",StringType,true)))

  /**
   * The type of the intermediate result kept while the aggregation is running
   */
  override def bufferSchema:StructType = StructType(Array(StructField("count",IntegerType,true)))

  /**
   * The return type of the function
   */
  override def dataType:DataType = IntegerType

  /**
   * Whether the function always returns the same output for the same input
   * true: yes
   */
  override def deterministic:Boolean = true

  /**
   * Initializes the given aggregation buffer
   */
  override def initialize(buffer:MutableAggregationBuffer):Unit = {buffer(0)=0}

  /**
   * Called once per input row within a group to fold the new value into the intermediate result
   * This is the node-local aggregation step, analogous to a Combiner in the Hadoop MapReduce model
   */
  override def update(buffer:MutableAggregationBuffer,input:Row):Unit={
    buffer(0) = buffer.getInt(0)+1
  }

  /**
   * After the local reduce on each distributed node completes, merges the partial results globally
   */
  override def merge(buffer1:MutableAggregationBuffer,buffer2:Row):Unit={
    buffer1(0) = buffer1.getInt(0)+buffer2.getInt(0)
  }

  /**
   * Returns the final result of the UDAF
   */
  override def evaluate(buffer:Row):Any = buffer.getInt(0)
}

Execution result (each of the three words appears 4 times in the 12-element array):

16/06/29 19:30:24 INFO DAGScheduler: ResultStage 5 (show at SparkSQLUDF.scala:48) finished in 1.625 s
+------+-----+------+
|  name|count|length|
+------+-----+------+
| Spark|    4|     5|
| Flink|    4|     5|
|Hadoop|    4|     6|
+------+-----+------+

16/06/29 19:30:24 INFO DAGScheduler: Job 3 finished: show at SparkSQLUDF.scala:48, took 1.717878 s

UDF and UDAF registration syntax

 sqlContext.udf.register("computeLength",(input:String) => input.length)

 sqlContext.udf.register("wordCount",new MyUDAF)
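
Once registered, the same functions can also be invoked through the DataFrame API instead of SQL. A minimal sketch, assuming the Spark 1.x callUDF helper from org.apache.spark.sql.functions:

 import org.apache.spark.sql.functions.{callUDF, col}

 bigDataDF.groupBy("name")
   .agg(callUDF("wordCount", col("name")).as("count"))
   .show()

Below is a second UDAF example, NumsAvg, which computes the average of a Double column: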
package com.spark.sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}

/**
 * A UDAF that computes the average of a Double column
 * The buffer tracks the row count and the running sum
 */
class NumsAvg extends UserDefinedAggregateFunction {

  // Input: a single Double column
  def inputSchema: StructType =
    StructType(StructField("nums", DoubleType) :: Nil)

  // Intermediate buffer: the row count and the running sum
  def bufferSchema: StructType = StructType(
    StructField("cnt", LongType) ::
      StructField("sum", DoubleType) :: Nil)

  // The final result is a Double
  def dataType: DataType = DoubleType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0
  }

  // Fold one input row into the local buffer
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Long](0) + 1
    buffer(1) = buffer.getAs[Double](1) + input.getAs[Double](0)
  }

  // Combine partial buffers computed on different partitions
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
  }

  // Final result: sum / count, rounded to five decimal places
  def evaluate(buffer: Row): Any = {
    val t = buffer.getDouble(1) / buffer.getLong(0)
    f"$t%1.5f".toDouble
  }
}
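
NumsAvg is registered the same way as the UDAF above. A minimal usage sketch; the SQLContext, table name, and column name here are hypothetical:

 sqlContext.udf.register("numsAvg", new NumsAvg)
 sqlContext.sql("select numsAvg(num) as avg from numsTable").show()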