UDAF Case Studies in Spark SQL

1. Counting the number of occurrences of each word

package com.bynear.spark_sql;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;


public class Spark_UDAF extends UserDefinedAggregateFunction {
    /**
     * inputSchema: the schema of the input arguments passed to this UDAF.
     *
     * @return
     */
    @Override
    public StructType inputSchema() {
        ArrayList<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("str", DataTypes.StringType, true));
        return DataTypes.createStructType(fields);
    }

    /**
     * bufferSchema: the schema of the intermediate aggregation buffer used during aggregation.
     *
     * @return
     */
    @Override
    public StructType bufferSchema() {
        ArrayList<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("count", DataTypes.IntegerType, true));
        return DataTypes.createStructType(fields);
    }

    /**
     * dataType: the data type of the value returned by this function.
     *
     * @return
     */
    @Override
    public DataType dataType() {
        return DataTypes.IntegerType;
    }

    /**
     * Consistency flag: if true, the function always produces the same result for the same input.
     *
     * @return
     */
    @Override
    public boolean deterministic() {
        return true;
    }

    /**
     * Initializes the intermediate aggregation buffer for each group. The initial value must
     * satisfy this invariant: merging two initial buffers with the merge method below must again
     * yield an initial buffer. For example, if the initial value were 1 and merge performed an
     * addition, merging two initial buffers would give 2, which is no longer an initial buffer;
     * such an initial value would be wrong. This is why the initial value is also called the
     * "zero value".
     *
     * @param buffer
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0);
    }

    /**
     * Updates the buffer with a new input row, similar to combineByKey: for each group, this
     * defines how the group's partial aggregate is updated when a new value arrives.
     *
     * @param buffer
     * @param input
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getInt(0) + 1);
    }

    /**
     * Merges two buffers by folding buffer2 into buffer1; this is used when combining partial
     * results from different partitions, similar to reduceByKey. Note that this method has no
     * return value: the implementation must write the merged result into buffer1. Because Spark
     * is distributed, the rows of one group may be partially aggregated (via update) on different
     * nodes, and those per-node partial aggregates are then combined here.
     *
     * @param buffer1
     * @param buffer2
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
    }

    /**
     * Computes the final aggregate value of a group from its merged intermediate buffer.
     *
     * @param buffer
     * @return
     */
    @Override
    public Object evaluate(Row buffer) {
        return buffer.getInt(0);
    }
}

package com.bynear.spark_sql;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class UDAF {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("UDAF").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
        List<String> nameList = Arrays.asList("xiaoming", "xiaoming", "刘德华", "古天乐", "feifei", "feifei", "feifei", "katong");
        // Parallelize the name list into a JavaRDD with 3 partitions
        JavaRDD<String> nameRDD = sc.parallelize(nameList, 3);
        // Convert each name into a Row, producing a JavaRDD<Row>
        JavaRDD<Row> nameRowRDD = nameRDD.map(new Function<String, Row>() {
            @Override
            public Row call(String name) throws Exception {
                return RowFactory.create(name);
            }
        });
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(fields);
        DataFrame namesDF = sqlContext.createDataFrame(nameRowRDD, structType);
        namesDF.registerTempTable("names");
        sqlContext.udf().register("countString", new Spark_UDAF());
        sqlContext.sql("select name,countString(name) as count  from names group by name").show();
        List<Row> rows = sqlContext.sql("select name,countString(name) as count  from names group by name").javaRDD().collect();
        for (Row row : rows) {
            System.out.println(row);
        }
    }
}
Output:

+--------+-----+
|    name|count|
+--------+-----+
|  feifei|    3|
|xiaoming|    2|
|     刘德华|    1|
|  katong|    1|
|     古天乐|    1|
+--------+-----+
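
The same registered UDAF can also be invoked through the DataFrame API instead of an SQL string. The snippet below is a minimal sketch (not part of the original example), assuming Spark 1.5/1.6 and reusing the namesDF DataFrame and the "countString" registration from the driver above; it relies on the static helpers org.apache.spark.sql.functions.callUDF and org.apache.spark.sql.functions.col.

// Minimal sketch: calling the registered UDAF via the DataFrame API.
// Add these static imports at the top of the driver class:
//   import static org.apache.spark.sql.functions.callUDF;
//   import static org.apache.spark.sql.functions.col;
DataFrame countedDF = namesDF
        .groupBy(col("name"))
        .agg(callUDF("countString", col("name")).alias("count"));
countedDF.show();   // prints the same name/count table as the SQL version above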

2. Computing the average price per brand

package com.bynear.spark_sql;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;

public class MyUDAF extends UserDefinedAggregateFunction {
    private StructType inputSchema;
    private StructType bufferSchema;

    public MyUDAF() {
        ArrayList<StructField> inputFields = new ArrayList<StructField>();
        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.DoubleType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        ArrayList<StructField> bufferFields = new ArrayList<StructField>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.DoubleType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.DoubleType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
//        The buffer has two fields: field 0 holds the running sum, initialized to 0,
//        and field 1 holds the running count, initialized to 0.
        buffer.update(0, 0.0);
        buffer.update(1, 0.0);
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        // Only update the buffer when the input value at index 0 is not null
        if (!input.isNullAt(0)) {
            // Field 0 (sum): add the input value to the buffered sum
            double updatedSum = buffer.getDouble(0) + input.getDouble(0);
            // Field 1 (count): increment the buffered count by 1
            double updatedCount = buffer.getDouble(1) + 1;
            buffer.update(0, updatedSum);
            buffer.update(1, updatedCount);
        }

    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        double mergedSum = buffer1.getDouble(0) + buffer2.getDouble(0);
        double mergedCount = buffer1.getDouble(1) + buffer2.getDouble(1);
        buffer1.update(0, mergedSum);
        buffer1.update(1, mergedCount);
    }

    @Override
    public Object evaluate(Row buffer) {
        return buffer.getDouble(0) / buffer.getDouble(1);
    }
}

package com.bynear.spark_sql;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.math.BigDecimal;
import java.util.ArrayList;

public class MyUDAF_SQL {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("myUDAF").setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(jsc);
        JavaRDD<String> lines = jsc.textFile("C://Users//Administrator//Desktop//fastJSon//sales.txt");
        JavaRDD<Row> map = lines.map(new Function<String, Row>() {
            @Override
            public Row call(String line) throws Exception {
                String[] lineSplit = line.split(",");
                return RowFactory.create(String.valueOf(lineSplit[0]), Double.valueOf(lineSplit[1]));
            }
        });
        ArrayList<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("salary", DataTypes.DoubleType, true));
        StructType structType = DataTypes.createStructType(fields);
        DataFrame df = sqlContext.createDataFrame(map, structType);
        sqlContext.udf().register("myAverage", new MyUDAF());
        df.registerTempTable("zjs_table");

        df.show();

        sqlContext.udf().register("twoDecimal", new UDF1<Double, Double>() {
            @Override
            public Double call(Double in) throws Exception {
                BigDecimal b = new BigDecimal(in);
                double res = b.setScale(2, BigDecimal.ROUND_HALF_DOWN).doubleValue();
                return res;
            }
        }, DataTypes.DoubleType);

        DataFrame resultDF = sqlContext.sql("select name,twoDecimal(myAverage(salary)) as 平均值 from zjs_table group by name ");
        resultDF.show();

    }
}
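
The two-decimal rounding can also be done with Spark SQL's built-in round(expr, d) function (available since Spark 1.5) instead of registering the twoDecimal UDF. Below is a minimal sketch of that variant, reusing the zjs_table registration and the myAverage UDAF from the driver above.

// Sketch: round the UDAF result with the built-in round() instead of a custom UDF.
DataFrame roundedDF = sqlContext.sql(
        "select name, round(myAverage(salary), 2) as 平均值 from zjs_table group by name");
roundedDF.show();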

Input text file (sales.txt):

三星,1542
三星,1548
三星,8456
三星,8866
中兴,1856
中兴,1752
苹果,1500
苹果,2500
苹果,3500
苹果,4500
苹果,5500

Output:

+----+-------+
|name| salary|
+----+-------+
|  三星|12345.0|
|  三星| 4521.0|
|  三星| 7895.0|
|  华为| 5421.0|
|  华为| 4521.0|
|  华为| 5648.0|
|  苹果|12548.0|
|  苹果| 7856.0|
|  苹果|45217.0|
|  苹果|89654.0|
+----+-------+

+----+--------+
|name|     平均值|
+----+--------+
|  三星| 8253.67|
|  华为| 5196.67|
|  苹果|38818.75|
+----+--------+

Note: pay attention to the encoding of the input text file (textFile reads lines as UTF-8), and make sure the types declared in the Java schema (DataTypes.DoubleType, DataTypes.StringType, ...) match the values parsed from each line.
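
If the file may contain blank or malformed lines, the parsing step above fails with a NumberFormatException on the first bad record. A defensive variant (a sketch, not part of the original code) can filter such lines out before converting the price column to Double:

// Sketch: skip blank/malformed lines instead of failing the whole job.
// Assumes the same "brand,price" line layout as sales.txt above.
JavaRDD<Row> safeRows = lines
        .filter(new Function<String, Boolean>() {
            @Override
            public Boolean call(String line) throws Exception {
                String[] parts = line.split(",");
                return parts.length == 2 && parts[1].trim().matches("\\d+(\\.\\d+)?");
            }
        })
        .map(new Function<String, Row>() {
            @Override
            public Row call(String line) throws Exception {
                String[] parts = line.split(",");
                return RowFactory.create(parts[0].trim(), Double.valueOf(parts[1].trim()));
            }
        });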







