SparkSQL创建RDD:UDAF(UserDefinedAggregatedFunction)用户自定义聚合函数【Java版纯代码】

要实现8个方法,8个方法中,最为重要的有3个:

initialize:初始化,在给,map端每一个分区的每一个key进行初始化,给0

update:在map端聚合

merge: 在reduce端聚合

Java版代码:

package com.bjsxt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hive.ql.parse.HiveParser_SelectClauseParser.selectClause_return;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
/**
 * 用户自定义聚合函数
 * @author Administrator
 *
 */
public class UDAF {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("test").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
        List<String> list = Arrays.asList("zhangsan", "lisi", "wangwu", "zhangsan", "zhangsan", "lisi", "wangwu");
        JavaRDD<String> parallelize = sc.parallelize(list);
        JavaRDD<Row> rowRDD = parallelize.map(new Function<String, Row>() {

            /**
             * map是一对一的类型
             * 进去的是String类型,出来的是row类型
             */
            @Override
            public Row call(String s) throws Exception {

                return RowFactory.create(s);
            }
        });
        List<StructField> fields = new ArrayList<StructField>();
        /**
         * 创建名为name的区域,类型为String
         */
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        /**
         * 创建schema        
         */
        StructType schema = DataTypes.createStructType(fields);
        /**
         * 将rdd和schema相聚和
         */
        DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
        /**
         * 创建一个用户名为user的表格
         */
        df.registerTempTable("user");
        /**
         * 注册一个UDAF函数,实现统计相同值的个数 
         * 注意: 这里可以自定义一个类,继承UserDefinedAggregatedFunction类也是可以的
         */
        sqlContext.udf().register("StringCount", new UserDefinedAggregateFunction() {
            /**
             * initialize相当于初始化:
             * map端每个元素的初试值都为零
             * reduce端的每个元素的初始值都为零
             * update相当于map端的聚合
             * merge相当于reduce端的聚合
             *map端的 merge的好处:
             *1.减少了suffer磁盘的数据量
             *2.减少了reduce端拉取的数据量
             *3.减少了reduce端的聚合次数
             */
            /**
             * 方法 用户自定义聚合函数
             */

            /**
             * ** 在进行聚合造作的时候,所要处理的数据的结果的类型
             * 
             * @return
             */

            @Override
            public StructType bufferSchema() {
                // TODO Auto-generated method stub
                return DataTypes.createStructType(
                        Arrays.asList(DataTypes.createStructField("bfferxx", DataTypes.IntegerType, true)));
            }

            /**
             * 指定UDAF函数计算后,返回的结果类型
             * 
             * @return
             */
            @Override
            public DataType dataType() {
                // TODO Auto-generated method stub
                return DataTypes.IntegerType;
            }

            /**
             * 确保一致性,一般用true 用以标记针对给定的一组输入 UDAF是否纵使生成相同的结果,
             * 
             * @return
             */
            @Override
            public boolean deterministic() {
                // TODO Auto-generated method stub
                return true;
            }

            /**
             * 最后返回一个和dataType方法的类型要一致的类型 返回UDAF最后的计算结果
             * 
             * @param arg0
             * @return row是已经分好组的key
             */
            @Override
            public Object evaluate(Row row) {

                return row.getInt(0);
            }

            /**
             * 初始化一个内部自定义的值 在Aggregate之前每组数据的初始化结果
             * 
             * @param buffer
             * 在map端每一个分区,中的每一个key做初始化,里边的值都为零
             * initialize不仅作用在map端,初始化元素为零
             * 而且还作用在reduce端,初始化reduce端的每个元素的值也为零
             */
            @Override
            public void initialize(MutableAggregationBuffer buffer) {
                buffer.update(0, 0);

            }

            /**
             * 指定输入字段的字段及类型
             * 
             * @return
             */
            @Override
            public StructType inputSchema() {

                return DataTypes.createStructType(
                        Arrays.asList(DataTypes.createStructField("namexxx", DataTypes.StringType, true)

                        ));

            }

            /**
             * 合并update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据会在多个节点上处理
             * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来 buffer1.getInt(0):大聚合的时候,上一次聚合后的值
             * buffer2.getInt(0):这次计算传入进来的update的结果 这里即是:最后在分布式节点上完成传后需要进行全局级别的merge操作
             * 
             * @param arg0
             * @param arg1
             *merge 作用在reduce端,将所有的数据拉取在一起
             *reduce端跨分区,跨节点
             * 
             */
            @Override
            public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
                buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
                /**
                 * 难不成各个节点的名字叫buffer吗?
                 */
            }

            /**
             * 更新,可以认为是一个一个地将组内的字段传递进来的,实现拼接的逻辑 buffer.getInt(0)获取的是上一次聚合后的值
             * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小的聚合 大聚合发生在reduce端
             * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后得值如何进行计算
             * 
             * @param arg0
             * @param arg1
             *            update相当于map端的聚合 作用在每一个分区的每一个小组
             */
            @Override
            public void update(MutableAggregationBuffer buffer, Row arg1) {
                buffer.update(0, buffer.getInt(0) + 1);

            }

        });
        /**
         * 真正的SQL查询语句
         */
        sqlContext.sql("select name,StringCount(name) as strCount from user group by name").show();
        sc.stop();
    }
}


Scala版代码:

class MyUDAF extends UserDefinedAggregateFunction  {
  // 聚合操作时,所处理的数据的类型
  def bufferSchema: StructType = {
    DataTypes.createStructType(Array(DataTypes.createStructField("aaa", IntegerType, true)))
  }
  // 最终函数返回值的类型
  def dataType: DataType = {
    DataTypes.IntegerType
  }

  def deterministic: Boolean = {
    true
  }
  // 最后返回一个最终的聚合值     要和dataType的类型一一对应
  def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }
  // 为每个分组的数据执行初始化值
  def initialize(buffer: MutableAggregationBuffer): Unit = {
     buffer(0) = 0
  }
  //输入数据的类型
  def inputSchema: StructType = {
    DataTypes.createStructType(Array(DataTypes.createStructField("input", StringType, true)))
  }
  // 最后merger的时候,在各个节点上的聚合值,要进行merge,也就是合并
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0)+buffer2.getAs[Int](0) 
  }
  // 每个组,有新的值进来的时候,进行分组对应的聚合值的计算
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0)+1
  }
}

object UDAF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setMaster("local").setAppName("udaf")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.makeRDD(Array("zhangsan","lisi","wangwu","zhangsan","lisi"))
    val rowRDD = rdd.map { x => {RowFactory.create(x)} }
    
    val schema = DataTypes.createStructType(Array(DataTypes.createStructField("name", StringType, true)))
    val df = sqlContext.createDataFrame(rowRDD, schema)
    df.show()
    df.registerTempTable("user")
    /**
     * 注册一个udaf函数
     */
    sqlContext.udf.register("StringCount", new MyUDAF())
    sqlContext.sql("select name ,StringCount(name) from user group by name").show()
    sc.stop()
  }
}

亲,鼓励一下我呗。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值