udaf开发小结

1 篇文章 0 订阅

背景:

在使用group by的SQL中,多进一出也是常见场景,例如hive自带的avg、sum都是多进一出,这个场景的自定义函数叫做用户自定义聚合函数(User Defiend Aggregate Function,UDAF)。

  • 在一些旧版的教程和文档中,都会提到UDAF开发的关键是继承UDAF.java;
  • 打开hive-exec的1.2.2版本源码,却发现UDAF类已被注解为Deprecated
  • UDAF类被废弃后,推荐的替代品有两种:实现GenericUDAFResolver2接口,或者继承AbstractGenericUDAFResolver类;两者是一样的,后者本身就是实现了前者的接口。

逻辑:

编写通用型UDAF需要两个类:解析器和计算器。解析器负责UDAF的参数检查,操作符的重载以及对于给定的一组参数类型来查找正确的计算器,建议继承AbstractGenericUDAFResolver类,具体实现如下:

@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
        throws SemanticException {
    if (parameters.length != 1) {
        throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected.");
    }
    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
        throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted.");
    }
    return new CollectListUDAFEvaluator();
}

计算器实现具体的计算逻辑,需要继承GenericUDAFEvaluator抽象类。
计算器有4种模式,由枚举类GenericUDAFEvaluator.Mode定义:

public static enum Mode {
    PARTIAL1, //从原始数据到部分聚合数据的过程(map阶段),将调用iterate()和terminatePartial()方法。
    PARTIAL2, //从部分聚合数据到部分聚合数据的过程(map端的combiner阶段),将调用merge() 和terminatePartial()方法。   
    FINAL,    //从部分聚合数据到全部聚合的过程(reduce阶段),将调用merge()和 terminate()方法。
    COMPLETE  //从原始数据直接到全部聚合的过程(表示只有map,没有reduce,map端直接出结果),将调用merge() 和 terminate()方法。
};

计算器必须实现的方法:
1、getNewAggregationBuffer():返回存储临时聚合结果的AggregationBuffer对象。
2、reset(AggregationBuffer agg):重置聚合结果对象,以支持mapper和reducer的重用。
3、iterate(AggregationBuffer agg,Object[] parameters):迭代处理原始数据parameters并保存到agg中。
4、terminatePartial(AggregationBuffer agg):以持久化的方式返回agg表示的部分聚合结果,这里的持久化意味着返回值只能Java基础类型、数组、基础类型包装器、Hadoop的Writables、Lists和Maps。
5、merge(AggregationBuffer agg,Object partial):合并由partial表示的部分聚合结果到agg中。
6、terminate(AggregationBuffer agg):返回最终结果。

通常还需要覆盖初始化方法ObjectInspector init(Mode m,ObjectInspector[] parameters),需要注意的是,在不同的模式下parameters的含义是不同的,比如m为 PARTIAL1 和 COMPLETE 时,parameters为原始数据;m为 PARTIAL2 和 FINAL 时,parameters仅为部分聚合数据(只有一个元素)。在 PARTIAL1 和 PARTIAL2 模式下,ObjectInspector  用于terminatePartial方法的返回值,在FINAL和COMPLETE模式下ObjectInspector 用于terminate方法的返回值。
 

下图对每个阶段调用了哪些方法说得很清楚:

下图对顺序执行的三个阶段和涉及方法做了详细说明:

 

示例:

下面实现一个计算器,按分组中元素的出现次数降序排序,并将每个元素的在分组中的出现次数也一起返回,格式为:
 [data1, num1, data2, num2, ...]

public static class CollectListUDAFEvaluator extends GenericUDAFEvaluator {
    protected PrimitiveObjectInspector inputKeyOI;
    protected StandardListObjectInspector loi;
    protected StandardListObjectInspector internalMergeOI;
    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
            throws HiveException {
        super.init(m, parameters);
        if (m == Mode.PARTIAL1) {
            inputKeyOI = (PrimitiveObjectInspector) parameters[0];
            return ObjectInspectorFactory.getStandardListObjectInspector(
                    ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI));
        } else {
            if ( parameters[0] instanceof StandardListObjectInspector ) {
                internalMergeOI = (StandardListObjectInspector) parameters[0];
                inputKeyOI = (PrimitiveObjectInspector) internalMergeOI.getListElementObjectInspector();
                loi = (StandardListObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
                return loi;
            } else {
                inputKeyOI = (PrimitiveObjectInspector) parameters[0];
                return ObjectInspectorFactory.getStandardListObjectInspector(
                        ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI));
            }
        }
    }
 
    static class MkListAggregationBuffer implements AggregationBuffer {
        List<Object> container = Lists.newArrayList();
    }
    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
        ((MkListAggregationBuffer) agg).container.clear();
    }
    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
        MkListAggregationBuffer ret = new MkListAggregationBuffer();
        return ret;
    }
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
            throws HiveException {
        if(parameters == null || parameters.length != 1){
            return;
        }
        Object key = parameters[0];
        if (key != null) {
            MkListAggregationBuffer myagg = (MkListAggregationBuffer) agg;
            putIntoList(key, myagg.container);
        }
    }
 
    private void putIntoList(Object key, List<Object> container) {
        Object pCopy = ObjectInspectorUtils.copyToStandardObject(key,  this.inputKeyOI);
        container.add(pCopy);
    }
 
    @Override
    public Object terminatePartial(AggregationBuffer agg)
            throws HiveException {
        MkListAggregationBuffer myagg = (MkListAggregationBuffer) agg;
        List<Object> ret = Lists.newArrayList(myagg.container);
        return ret;
    }
    @Override
    public void merge(AggregationBuffer agg, Object partial)
            throws HiveException {
        if(partial == null){
            return;
        }
        MkListAggregationBuffer myagg = (MkListAggregationBuffer) agg;
        List<Object> partialResult = (List<Object>) internalMergeOI.getList(partial);
        for (Object ob: partialResult) {
            putIntoList(ob, myagg.container);
        }
        return;
    }
 
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
        MkListAggregationBuffer myagg = (MkListAggregationBuffer) agg;
        Map<Text, Integer> map = Maps.newHashMap();
        for (int i = 0; i< myagg.container.size() ; i++){
            Text key = (Text) myagg.container.get(i);
            if (map.containsKey(key)) {
                map.put(key, map.get(key) + 1);
            }else{
                map.put(key, 1);
            }
        }
        List<Map.Entry<Text, Integer>> listData = Lists.newArrayList(map.entrySet());
        Collections.sort(listData, new Comparator<Map.Entry<Text, Integer>>() {
            public int compare(Map.Entry<Text, Integer> o1, Map.Entry<Text, Integer> o2) {
                if (o1.getValue() < o2.getValue())
                    return 1;
                else if (o1.getValue() == o2.getValue())
                    return 0;
                else
                    return -1;
            }
        });
 
        List<Object> ret =  Lists.newArrayList();
        for(Map.Entry<Text, Integer> entry : listData){
            ret.add(entry.getKey());
            ret.add(new Text(entry.getValue().toString()));
        }
        return ret;
    }
}

部署:

1.打依赖包并上传,如hdfs://hadoop/user/hive/jars/doris-udaf-test.jar

2.创建测试函数:CREATE FUNCTION hive_udf.collect_list_test as 'com.xxx.hiveudf.userpathudaf.UserPath' USING JAR 'hdfs://hadoop/user/hive/jars/doris-udaf-test.jar';

3.创建测试表:create table test (...)

4.查询验证:select id, collect_list(value) from test group by id;

5.查看mr yarn任务日志:yarn logs -applicationId application_1655870182598_24029 >a.log

参考链接:

hive学习笔记之十:用户自定义聚合函数(UDAF) - 腾讯云开发者社区-腾讯云

Hive通用型自定义聚合函数(UDAF)_沧南的博客-CSDN博客_hive 自定义udaf

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值