spark1.5 自定义聚合函数UDAF

自定义聚合函数需要实现UserDefinedAggregateFunction,以下是该抽象类的定义,加了一点注释:

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.expressions

import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.types._
import org.apache.spark.annotation.Experimental

/**
 * :: Experimental ::
 * The base class for implementing user-defined aggregate functions (UDAF).
 */
@Experimental
abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * A [[StructType]] represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   */
   //输入参数的数据类型定义
  def inputSchema: StructType

  /**
   * A [[StructType]] represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
   * the returned [[StructType]] will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   */
   //聚合的中间过程中产生的数据的数据类型定义
  def bufferSchema: StructType

  /**
   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
   */
   //聚合结果的数据类型定义
  def dataType: DataType

  /**
   * Returns true if this function is deterministic, i.e. given the same input,
   * always return the same output.
   */
   //一致性检验,如果为true,那么输入不变的情况下计算的结果也是不变的。
  def deterministic: Boolean

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   */
   //设置聚合中间buffer的初始值,但需要保证这个语义:两个初始buffer调用下面实现的merge方法后也应该为初始buffer。
   //即如果你初始值是1,然后你merge是执行一个相加的动作,两个初始buffer合并之后等于2,不会等于初始buffer了。这样的初始值就是有问题的,所以初始值也叫"zero value"
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   */
   //用输入数据input更新buffer值,类似于combineByKey
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   */
   //合并两个buffer,将buffer2合并到buffer1.在合并两个分区聚合结果的时候会被用到,类似于reduceByKey
   //这里要注意该方法没有返回值,在实现的时候是把buffer2合并到buffer1中去,你需要实现这个合并细节。
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   */
   //计算并返回最终的聚合结果
  def evaluate(buffer: Row): Any

  /**
   * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
   */
   //所有输入数据进行聚合
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }

  /**
   * Creates a [[Column]] for this UDAF using the distinct values of the given
   * [[Column]]s as input arguments.
   */
   //所有输入数据去重后进行聚合
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}

/**
 * :: Experimental ::
 * A [[Row]] representing an mutable aggregation buffer.
 *
 * This is not meant to be extended outside of Spark.
 */
@Experimental
abstract class MutableAggregationBuffer extends Row {

  /** Update the ith value of this buffer. */
  def update(i: Int, value: Any): Unit
}

下面我们自己实现一个求平均数的聚合函数:

class MyAvg extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = {
      new StructType().add("myinput", DoubleType)
    }

    override def bufferSchema: StructType = {
      new StructType().add("mycnt", LongType).add("mysum", DoubleType)
    }


    override def dataType: DataType = DoubleType


    override def deterministic: Boolean = true


    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0l)
      buffer.update(1, 0d)
    }
    //partitions内部combine
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      buffer.update(0, buffer.getAs[Long](0) + 1) //条数加1
      buffer.update(1, buffer.getAs[Double](1) + input.getAs[Double](0)) //输入汇总
      //目前1.5版本好像还有点问题,不能通过字段名来取值
      //      buffer.update(0, buffer.getAs[Long]("mycnt") + 1) //条数加1
      //      buffer.update(1, buffer.getAs[Double]("mysum") + input.getAs[Double]("myinput")) //输入汇总
    }

    //MutableAggregationBuffer继承自Row
    //partitions间合并
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getAs[Long](0) + buffer2.getAs[Long](0))
      buffer1.update(1, buffer1.getAs[Double](1) + buffer2.getAs[Double](1))
      //目前1.5版本好像还有点问题,不能通过字段名来取值
      //      buffer1.update(0, buffer1.getAs[Long]("mycnt") + buffer2.getAs[Long]("mycnt"))
      //      buffer1.update(1, buffer1.getAs[Double]("mysum") + buffer2.getAs[Double]("mysum"))
    }


    override def evaluate(buffer: Row): Any = {
      //计算平均值
      val avg = buffer.getAs[Double](1) / buffer.getAs[Long](0)
      //目前1.5版本好像还有点问题,不能通过字段名来取值
      //      val avg = buffer.getAs[Double]("mysum") / buffer.getAs[Long]("mycnt")
      f"$avg%.2f".toDouble
    }


  }
自定义聚合函数需要实现以上抽象类的这8个方法。

下面我们写一个测试自定义UDAF的测试类:

import org.apache.log4j.{Logger, Level}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{StringType, DoubleType, StructField, StructType}

/**
 * @author Administrator
 */
object AvgTest {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.ERROR)
    val conf = new SparkConf().setAppName("UDAF TEST").setMaster("local")

    val sc = new SparkContext(conf)

    val sqlContext = new SQLContext(sc)


    val nums = Seq(("a", 1.1), ("a", 2.1), ("b", 1.1))
    val numsRDD = sc.parallelize(nums, 2)

    val numsRowRDD = numsRDD.map { x => Row(x._1, x._2) }

    val schema = new StructType().add(StructField("id", StringType, nullable = false)).add(StructField("num", DoubleType, nullable = true))

    val numsDF = sqlContext.createDataFrame(numsRowRDD, schema)

    numsDF.registerTempTable("mytable")
    sqlContext.sql("select id,avg(num) from mytable  group by id").collect().foreach { x => println(s"id:${x(0)},avg:${x(1)}") }
    sqlContext.udf.register("myAvg", new MyAvg)
    sqlContext.sql("select id,myAvg(num) from mytable group by id ").collect().foreach { x => println(s"id:${x(0)},avg:${x(1)}") }

    sc.stop()
  }
}

使用原生的avg和自定义的avg的输出的结果一致:

id:a,avg:1.6
id:b,avg:1.1

id:a,avg:1.6
id:b,avg:1.1


  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值