User-Defined Functions in Hive: UDF, UDTF, UDAF

Overview

Hive ships with a number of built-in functions, such as max/min, but the set is limited. When the built-in functions cannot cover your business logic, you can write a user-defined function (UDF: user-defined function). By category, user-defined functions come in three kinds:

  1. UDF (User-Defined Function): one row in, one row out, like the built-in lower
  2. UDAF (User-Defined Aggregation Function): aggregate function, many rows in, one row out, like count/max/min
  3. UDTF (User-Defined Table-Generating Function): one row in, many rows out, like explode

Steps to Create a UDF

1. Create a project and pull in the Maven dependency

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.hl</groupId>
    <artifactId>hive</artifactId>
    <version>0.1</version>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.apache.hive</groupId>
            <artifactId>hive-exec</artifactId>
            <version>2.3.7</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>${maven.compiler.source}</source>
                    <target>${maven.compiler.target}</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
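
A side note (my suggestion, not part of the original setup): since Hive already has hive-exec on its classpath at runtime, you may want to mark the dependency as provided so it is not bundled into your jar:

<dependency>
    <groupId>org.apache.hive</groupId>
    <artifactId>hive-exec</artifactId>
    <version>2.3.7</version>
    <scope>provided</scope>
</dependency>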

2. Extend the relevant base class and implement the function

One in, one out: UDF

package hive.User_Defined_Functions;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.Text;
/**
 * @describe: user-defined function, one row in, one row out
 * Steps:
 *  1. Extend org.apache.hadoop.hive.ql.exec.UDF and implement an evaluate method; evaluate supports overloading.
 *  2. Add the jar at the Hive CLI:
 *      add jar linux_jar_path
 *  3. Create the function at the Hive CLI:
 *      create [temporary] function [dbname.]function_name AS class_name;
 *  To drop the function at the Hive CLI:
 *      drop [temporary] function [if exists] [dbname.]function_name;
 */
public class MyUDF extends UDF {

    /**
     * Implement an evaluate method; evaluate supports overloading.
     * Note: a UDF must have a return type. It may return null, but the return type cannot be void.
     */
    public Text evaluate(final Text s) {
        if (s == null) {
            return null;
        }
        return new Text(s.toString().toLowerCase());
    }
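
    /**
     * Hypothetical overload sketch (not in the original post): since evaluate
     * supports overloading, a second variant could, for example, substitute a
     * default value when the input is NULL.
     */
    public Text evaluate(final Text s, final Text defaultValue) {
        if (s == null) {
            return defaultValue;
        }
        return new Text(s.toString().toLowerCase());
    }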
}
/*
 * Test:
 *  add jar udf.jar
 *  create temporary function mylower as "hive.User_Defined_Functions.MyUDF";
 *  select mylower(name) from student;
 */
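
For reference, newer Hive code usually extends GenericUDF rather than the old UDF base class, since GenericUDF is more flexible and is the recommended route for new work. Below is a minimal sketch of the same lowercase function on top of GenericUDF; the class name is mine, and the evaluate body takes the shortcut of calling toString() on the input instead of resolving it through the input ObjectInspector:

package hive.User_Defined_Functions;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;

public class MyGenericLowerUDF extends GenericUDF {
    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        //validate the argument count and declare the return type (string)
        if (arguments.length != 1) {
            throw new UDFArgumentException("mylower expects exactly one argument");
        }
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        //shortcut: a robust version would parse the value via the input ObjectInspector
        Object value = arguments[0].get();
        if (value == null) {
            return null;
        }
        return new Text(value.toString().toLowerCase());
    }

    @Override
    public String getDisplayString(String[] children) {
        //how the call is rendered in EXPLAIN output
        return "mylower(" + children[0] + ")";
    }
}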

One in, many out: UDTF

package hive.User_Defined_Functions;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import java.util.ArrayList;
import java.util.List;

/**
 * @describe: a generic user-defined table-generating function (UDTF), which produces a variable number of output rows for a single input row
 * Example: a UDTF that splits a string on an arbitrary separator into individual words
 *   Input:  "hello,world,hadoop,hive"
 *   Call:   myudtf(line, ",")
 *   Output: one row per word: hello / world / hadoop / hive
 */
public class MyUDTF extends GenericUDTF {
    private ArrayList<String> outList = new ArrayList<>();
    @Override
    public StructObjectInspector initialize(StructObjectInspector argOIs) throws UDFArgumentException {
        //1. Define the containers for the output column names and types
        List<String> fieldNames = new ArrayList<>();
        List<ObjectInspector> fieldOIs = new ArrayList<>();
        //2. Add the output column name and type
        fieldNames.add("lineToWord");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    public void process(Object[] args) throws HiveException {
        //1. Get the raw input value
        String arg = args[0].toString();
        //2. Get the second argument, here the separator
        String splitKey = args[1].toString();
        //3. Split the input on the separator (note: String.split treats it as a regex)
        String[] fields = arg.split(splitKey);
        //4. Iterate over the pieces and emit each one as an output row
        for (String field : fields) {
            //the collection is reused, so clear it first
            outList.clear();
            //add the word to the collection
            outList.add(field);
            //emit the collection as one output row
            forward(outList);
        }
    }
    @Override
    public void close() throws HiveException {
    }
}
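
A usage sketch in the same style as the UDF test above (jar, function, table, and column names are hypothetical):

/*
 * Test:
 *  add jar udtf.jar
 *  create temporary function myudtf as "hive.User_Defined_Functions.MyUDTF";
 *  select myudtf(line, ",") from word_lines;
 *  -- combined with other columns via LATERAL VIEW:
 *  select id, word from word_lines lateral view myudtf(line, ",") t as word;
 */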

Many in, one out: UDAF

A UDAF involves the two classes below: a resolver, which validates the argument types and hands back an evaluator, and the evaluator, which implements the aggregation itself (Hive drives it through the modes PARTIAL1, PARTIAL2, FINAL, and COMPLETE):

org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator

Demo 1: summing numbers

package hive.User_Defined_Functions;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.LongWritable;

/**
 * @describe: many in, one out; computes the sum of the input values
 */
public class MyUDAFSUM extends GenericUDAFEvaluator {
    //ObjectInspector used to read the input values
    private PrimitiveObjectInspector inputOI;
    //the value returned to Hive
    private LongWritable result;
    private boolean warned;

    public MyUDAFSUM() {
        this.warned = false;
    }

    /**
     * Declares the UDAF's return type; here the return type is long.
     */
    @Override
    public ObjectInspector init(GenericUDAFEvaluator.Mode m, ObjectInspector[] parameters) throws HiveException {
        assert (parameters.length == 1);
        //initialize the input side
        super.init(m, parameters);
        result = new LongWritable(0L);
        inputOI = ((PrimitiveObjectInspector) parameters[0]);
        //declare the output type
        return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    }

    /**
     * Allocate the memory a new aggregation needs; it stores the running sum
     * during the mapper, combiner, and reducer phases.
     *
     * @return GenericUDAFEvaluator.AggregationBuffer that holds the intermediate result of the aggregation
     */
    @Override
    public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
        SumLong result = new SumLong();
        reset(result);
        return result;
    }

    /**
     * Reset the aggregation state so the same buffer can be reused.
     */
    @Override
    public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
        SumLong myagg = (SumLong) agg;
        myagg.empty = true;
        myagg.sum = 0L;
    }

    /**
     * Iterate over the raw input rows.
     */
    @Override
    public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
        assert (parameters.length == 1);
        try {
            merge(agg, parameters[0]);
        } catch (NumberFormatException e) {
            if (!(this.warned)) {
                this.warned = true;
            }
        }
    }

    /**
     * Merge a partial aggregation result into the buffer.
     * <p>
     * Note:
     * PrimitiveObjectInspectorUtils.getLong(value, its ObjectInspector) extracts the value as a long
     */
    @Override
    public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial) throws HiveException {
        if (partial != null) {
            SumLong myagg = (SumLong) agg;
            myagg.sum += PrimitiveObjectInspectorUtils.getLong(partial, inputOI);
            myagg.empty = false;
        }
    }

    /**
     * Return the partial aggregation result.
     */
    @Override
    public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
        return terminate(agg);
    }

    /**
     * Return the final aggregation result.
     */
    @Override
    public Object terminate(GenericUDAFEvaluator.AggregationBuffer agg) {
        SumLong myagg = (SumLong) agg;
        if (myagg.empty) {
            return null;
        }
        result.set(myagg.sum);
        return result;
    }

    /**
     * Buffer class that stores the running sum.
     */
    @GenericUDAFEvaluator.AggregationType(estimable = true)
    static class SumLong extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        boolean empty;
        long sum;

        //estimated storage size: 12 bytes
        public int estimate() {
            return 12;
        }
    }
}
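
Note that Hive's create function needs a class it can recognize as a UDAF, i.e. a resolver, and demo 1 above only defines the evaluator. A minimal resolver sketch (my addition; the argument-type checks are omitted here, see demo 2 for a full version):

package hive.User_Defined_Functions;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

public class MyUDAFSUMResolver extends AbstractGenericUDAFResolver {
    //hand back the evaluator defined above; argument validation omitted for brevity
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        return new MyUDAFSUM();
    }
}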

Demo 2: averaging numbers

package hive.User_Defined_Functions;

import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;

import java.util.ArrayList;

/**
 * @describe:
 * Example: compute the average of several numbers
 *    sum     = running total of the values
 *    count   = number of values
 *    average = sum / count
 */
public class MyUDAFAverage extends AbstractGenericUDAFResolver {

    /**
     * Validate the input argument types; if validation passes, return the evaluator that performs the aggregation
     * @param parameters the argument types
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
            throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected.");
        }

        if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted but " + parameters[0].getTypeName() + " is passed.");
        }
        switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
            case BYTE:
            case SHORT:
            case INT:
            case LONG:
            case FLOAT:
            case DOUBLE:
            case STRING:
            case TIMESTAMP: return new GenericUDAFAverageEvaluator();
            case BOOLEAN:
            default: throw new UDFArgumentTypeException(0, "Only numeric or string type arguments are accepted but " + parameters[0].getTypeName() + " is passed.");
        }
    }

    /**
     * GenericUDAFAverageEvaluator: computes the average.
     * A static inner class extending the GenericUDAFEvaluator base class; it does the actual processing.
     */
    public static class GenericUDAFAverageEvaluator extends GenericUDAFEvaluator {
        //ObjectInspector for the raw input
        PrimitiveObjectInspector inputOI;
        //struct ObjectInspector for the intermediate (count, sum) result
        StructObjectInspector soi;
        //the count field of the intermediate struct
        StructField countField;
        //the sum field of the intermediate struct
        StructField sumField;
        LongObjectInspector countFieldOI;
        DoubleObjectInspector sumFieldOI;

        //holds the intermediate output values
        Object[] partialResult;
        //the final output value
        DoubleWritable result;

        /*
         * Initialization: for each mode, extract the input ObjectInspector (OI) and return the output OI.
         * init() runs once for every mode (Mode).
         * 1. Input parameters:
         *  1.1. In PARTIAL1 and COMPLETE mode the input is raw data (a single value):
         *       the input to iterate() is typed by a PrimitiveObjectInspector
         *       (a WritableDoubleObjectInspector instance), which parses the argument values.
         *  1.2. In PARTIAL2 and FINAL mode the input is a partial aggregate (count and sum):
         *       the input to merge() is typed by a StructObjectInspector
         *       (a StandardStructObjectInspector instance), which parses the argument values.
         * 2. Return type:
         *  2.1. In PARTIAL1 and PARTIAL2 mode it is the OI for the value returned by terminatePartial():
         *       a StructObjectInspector (StandardStructObjectInspector instance).
         *  2.2. In FINAL and COMPLETE mode it is the OI for the value returned by terminate():
         *       a PrimitiveObjectInspector (WritableDoubleObjectInspector instance).
         */
        @Override
        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
            assert (parameters.length == 1);

            //initialize the input side
            super.init(mode, parameters);
            //raw data to partial aggregate (PARTIAL1) || raw data straight to the final result (COMPLETE)
            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
                //partial aggregate to partial aggregate (PARTIAL2) || partial aggregate to final result (FINAL)
                soi = (StructObjectInspector) parameters[0];
                countField = soi.getStructFieldRef("count");
                sumField = soi.getStructFieldRef("sum");
                //each element of the struct must be parsed with its own primitive ObjectInspector
                countFieldOI = (LongObjectInspector) countField.getFieldObjectInspector();
                sumFieldOI = (DoubleObjectInspector) sumField.getFieldObjectInspector();
            }

            //the intermediate output consists of count and sum, represented as an array
            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
                //the partial aggregation result is an array
                partialResult = new Object[2];
                partialResult[0] = new LongWritable(0);
                partialResult[1] = new DoubleWritable(0);
                //build the struct ObjectInspector that types the aggregate array; it is constructed from a list of field names and a list of field types
                ArrayList<String> fname = new ArrayList<>();
                fname.add("count");
                fname.add("sum");
                ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
                //note: these two OIs describe the two elements of partialResult[], so they must match
                foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
            } else {
                //FINAL/COMPLETE: the final result is a single value, typed with a primitive ObjectInspector
                result = new DoubleWritable(0);
                return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            }
        }

        /*
         * In-memory buffer structure for the aggregation state
         */
        static class AverageAgg implements AggregationBuffer {
            long count;
            double sum;
        }

        /**
         * Allocate the memory a new aggregation needs
         */
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            AverageAgg result = new AverageAgg();
            reset(result);
            return result;
        }

        /**
         * Reset the aggregation state so the same buffer can be reused
         */
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            AverageAgg myagg = (AverageAgg) agg;
            myagg.count = 0;
            myagg.sum = 0;
        }

        boolean warned = false;

        /*
         * Iterate over the raw input rows
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) {
            assert (parameters.length == 1);
            Object p = parameters[0];
            if (p != null) {
                AverageAgg myagg = (AverageAgg) agg;
                try {
                    //parse the value of p using the primitive input ObjectInspector
                    double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);
                    myagg.count++;
                    myagg.sum += v;
                } catch (NumberFormatException e) {
                    if (!warned) {
                        warned = true;
                    }
                }
            }
        }

        /*
         * Produce the partial aggregation result
         */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            AverageAgg myagg = (AverageAgg) agg;
            ((LongWritable) partialResult[0]).set(myagg.count);
            ((DoubleWritable) partialResult[1]).set(myagg.sum);
            return partialResult;
        }

        /*
         * Merge a partial aggregation result
         * Note: an Object[] is itself an Object, so partial arrives here as an Object[] array
         */
        @Override
        public void merge(AggregationBuffer agg, Object partial) {
            if (partial != null) {
                AverageAgg myagg = (AverageAgg) agg;
                //use the StandardStructObjectInspector to pull the element values out of the partial array
                Object partialCount = soi.getStructFieldData(partial, countField);
                Object partialSum = soi.getStructFieldData(partial, sumField);
                //parse each Object with its primitive ObjectInspector
                myagg.count += countFieldOI.get(partialCount);
                myagg.sum += sumFieldOI.get(partialSum);
            }
        }

        /*
         * Produce the final aggregation result
         */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            AverageAgg myagg = (AverageAgg) agg;
            if (myagg.count == 0) {
                return null;
            } else {
                result.set(myagg.sum / myagg.count);
                return result;
            }
        }
    }

}
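
A usage sketch in the same style as the UDF test (jar, function, table, and column names are hypothetical):

/*
 * Test:
 *  add jar udaf.jar
 *  create temporary function myavg as "hive.User_Defined_Functions.MyUDAFAverage";
 *  select myavg(score) from student;
 */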

3. Build the jar and add it to Hive

add jar jar_path

4. Create the function

create [temporary] function [dbname.]function_name AS class_name; 
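
For example, registering and calling the lowercase UDF from step 2 (the jar path is hypothetical):

add jar /opt/jars/hive-0.1.jar;
create temporary function mylower as "hive.User_Defined_Functions.MyUDF";
select mylower(name) from student;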

 
