Hive Custom Functions


Preface

Custom functions extend the set of functions available in Hive queries. A custom function can be created either as a temporary function or as a permanent function.

1. Importing the libraries

The Maven dependencies below are required by all of the following sections:

<dependency>
	<groupId>org.apache.hive</groupId>
	<artifactId>hive-exec</artifactId>
	<version>2.3.9</version>
</dependency>
<dependency>
	<groupId>org.apache.hive</groupId>
	<artifactId>hive-jdbc</artifactId>
	<version>2.3.9</version>
</dependency>
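
As a quick sanity check that the dependencies resolve, the hive-jdbc artifact above can be used to run the statements from the next section programmatically. The sketch below is only an illustration: the HiveServer2 URL, user, and password are assumptions and need to be adapted to your environment.

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;

public class ShowFunctionsExample {
    public static void main(String[] args) throws Exception {
        // Make sure the Hive JDBC driver is registered.
        Class.forName("org.apache.hive.jdbc.HiveDriver");
        // Hypothetical HiveServer2 address; replace host, port, user and password with your own.
        String url = "jdbc:hive2://localhost:10000/default";
        try (Connection conn = DriverManager.getConnection(url, "hive", "");
             Statement stmt = conn.createStatement();
             // Any of the SHOW/DESCRIBE statements from the next section can be run here.
             ResultSet rs = stmt.executeQuery("SHOW FUNCTIONS LIKE '%avg%'")) {
            while (rs.next()) {
                System.out.println(rs.getString(1));
            }
        }
    }
}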

2. Querying functions from the command line

List all functions:
SHOW FUNCTIONS;

+---------------------------+
|         tab_name          |
+---------------------------+
| avg                       |
| uuid                      |
| sum                       |
+---------------------------+

List functions matching a pattern:
SHOW FUNCTIONS LIKE '%avg%';

+---------------------------+
|         tab_name          |
+---------------------------+
| avg                       |
+---------------------------+

Show a brief description of a function:
DESCRIBE FUNCTION avg;

+------------------------------------------------+
|                    tab_name                    |
+------------------------------------------------+
| avg(x) - Returns the mean of a set of numbers  |
+------------------------------------------------+

Show the detailed description of a function:
DESCRIBE FUNCTION EXTENDED avg;

+--------------------------------------------------------------------------+
|                                 tab_name                                 |
+--------------------------------------------------------------------------+
| avg(x) - Returns the mean of a set of numbers                            |
| Function class:org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage  |
| Function type:BUILTIN                                                    |
+--------------------------------------------------------------------------+

3. Adding custom functions to the Hive function library

Register a custom function as a permanent function:

-- Switch to the target database
use test;
-- Drop the function if it already exists
drop function if exists fn_test_avg;
-- Create the function: specify the function name, the class name, and the path of the jar on HDFS
create function fn_test_avg as 'com.study.udaf.TestAVG' using jar 'hdfs://nameservicetentant/user/study/jar/study.jar';

Register a custom function as a temporary function:

-- Switch to the target database
use test;
-- Drop the function if it already exists
drop temporary function if exists fn_test_avg;
-- Create the function: specify the function name, the class name, and the path of the jar on HDFS
create temporary function fn_test_avg as 'com.study.udaf.TestAVG' using jar 'hdfs://nameservicetentant/user/study/jar/study.jar';

I. Standard functions: UDF

A standard function takes one or more columns from a single row as parameters and returns a single value.
Built-in examples include abs() (absolute value) and reverse() (string reversal).

Extending UDF

This approach can be used when the function takes a single input parameter and the input and output data types are the same.

//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by Fernflower decompiler)
//

package org.apache.hadoop.hive.ql.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils;
import org.apache.hadoop.io.Text;
/**
* name: the function name
* value: description of the function's return value
* extended: an example of calling the function
*/
@Description(
    name = "reverse",
    value = "_FUNC_(str) - reverse str",
    extended = "Example:\n  > SELECT _FUNC_('Facebook') FROM src LIMIT 1;\n  'koobecaF'"
)
public class UDFReverse extends UDF {
    private final Text result = new Text();

    public UDFReverse() {
    }

    private void reverse(byte[] arr, int first, int last) {
        for(int i = 0; i < (last - first + 1) / 2; ++i) {
            byte temp = arr[last - i];
            arr[last - i] = arr[first + i];
            arr[first + i] = temp;
        }

    }
	/**
	* The parent class constructor installs a DefaultUDFMethodResolver, whose
	* getEvalMethod() looks up this evaluate() method by reflection:
	*   public Method getEvalMethod(List<TypeInfo> argClasses) throws UDFArgumentException {
	*       return FunctionRegistry.getMethodInternal(this.udfClass, "evaluate", false, argClasses);
	*   }
	*/
    public Text evaluate(Text s) {
        if (s == null) {
            return null;
        } else {
            this.result.set(s);
            byte[] data = this.result.getBytes();
            int prev = 0;

            for(int i = 1; i < this.result.getLength(); ++i) {
                if (GenericUDFUtils.isUtfStartByte(data[i])) {
                    this.reverse(data, prev, i - 1);
                    prev = i;
                }
            }

            this.reverse(data, prev, this.result.getLength() - 1);
            this.reverse(data, 0, this.result.getLength() - 1);
            return this.result;
        }
    }
}
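
Following the same pattern, a custom UDF of your own can be as small as the sketch below; the package name com.study.udf, the class TestUpper, and the function name fn_test_upper are hypothetical examples, not part of Hive.

package com.study.udf; // hypothetical package

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.Text;

@Description(
    name = "fn_test_upper",
    value = "_FUNC_(str) - returns str converted to upper case"
)
public class TestUpper extends UDF {
    private final Text result = new Text();

    // Hive finds this evaluate() method by reflection, as described above.
    public Text evaluate(Text s) {
        if (s == null) {
            return null;
        }
        result.set(s.toString().toUpperCase());
        return result;
    }
}

Packaged into a jar, such a class can then be registered with CREATE FUNCTION exactly as shown in section 3.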

GenericUDF with a single input parameter

Defines a function with a single input parameter and a single output value. The input parameter's category and data type can be validated.


package org.apache.hadoop.hive.ql.udf.generic;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncAbsDecimalToDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncAbsDoubleToDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncAbsLongToLong;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
/**
* name: the function name
* value: description of the function's return value
* extended: an example of calling the function
*/
@Description(
    name = "abs",
    value = "_FUNC_(x) - returns the absolute value of x",
    extended = "Example:\n  > SELECT _FUNC_(0) FROM src LIMIT 1;\n  0\n  > SELECT _FUNC_(-5) FROM src LIMIT 1;\n  5"
)
@VectorizedExpressions({FuncAbsLongToLong.class, FuncAbsDoubleToDouble.class, FuncAbsDecimalToDecimal.class})
public class GenericUDFAbs extends GenericUDF {
    private transient PrimitiveCategory inputType;
    private final DoubleWritable resultDouble = new DoubleWritable();
    private final LongWritable resultLong = new LongWritable();
    private final IntWritable resultInt = new IntWritable();
    private final HiveDecimalWritable resultDecimal = new HiveDecimalWritable();
    private transient PrimitiveObjectInspector argumentOI;
    private transient Converter inputConverter;

    public GenericUDFAbs() {
    }
	/**
	* Validates the input parameters and defines the output type.
	* arguments: the function's input parameters
	*/
	@Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length != 1) {
            throw new UDFArgumentLengthException("ABS() requires 1 argument, got " + arguments.length);
        } else if (arguments[0].getCategory() != Category.PRIMITIVE) {
            throw new UDFArgumentException("ABS only takes primitive types, got " + arguments[0].getTypeName());
        } else {
            this.argumentOI = (PrimitiveObjectInspector)arguments[0];
            this.inputType = this.argumentOI.getPrimitiveCategory();
            ObjectInspector outputOI = null;
            switch(this.inputType) {
            case SHORT:
            case BYTE:
            case INT:
                this.inputConverter = ObjectInspectorConverters.getConverter(arguments[0], PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                outputOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
                break;
            case LONG:
                this.inputConverter = ObjectInspectorConverters.getConverter(arguments[0], PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                outputOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
                break;
            case FLOAT:
            case STRING:
            case DOUBLE:
                this.inputConverter = ObjectInspectorConverters.getConverter(arguments[0], PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
                break;
            case DECIMAL:
                outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(((PrimitiveObjectInspector)arguments[0]).getTypeInfo());
                this.inputConverter = ObjectInspectorConverters.getConverter(arguments[0], (ObjectInspector)outputOI);
                break;
            default:
                throw new UDFArgumentException("ABS only takes SHORT/BYTE/INT/LONG/DOUBLE/FLOAT/STRING/DECIMAL types, got " + this.inputType);
            }

            return (ObjectInspector)outputOI;
        }
    }
	/**
	* The execution method; arguments holds the values passed to the function.
	* Note that the returned value must match the type defined in initialize().
	*/
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        Object valObject = arguments[0].get();
        if (valObject == null) {
            return null;
        } else {
            switch(this.inputType) {
            case SHORT:
            case BYTE:
            case INT:
                valObject = this.inputConverter.convert(valObject);
                this.resultInt.set(Math.abs(((IntWritable)valObject).get()));
                return this.resultInt;
            case LONG:
                valObject = this.inputConverter.convert(valObject);
                this.resultLong.set(Math.abs(((LongWritable)valObject).get()));
                return this.resultLong;
            case FLOAT:
            case STRING:
            case DOUBLE:
                valObject = this.inputConverter.convert(valObject);
                if (valObject == null) {
                    return null;
                }

                this.resultDouble.set(Math.abs(((DoubleWritable)valObject).get()));
                return this.resultDouble;
            case DECIMAL:
                HiveDecimalObjectInspector decimalOI = (HiveDecimalObjectInspector)this.argumentOI;
                HiveDecimalWritable val = decimalOI.getPrimitiveWritableObject(valObject);
                if (val != null) {
                    this.resultDecimal.set(val);
                    this.resultDecimal.mutateAbs();
                    val = this.resultDecimal;
                }

                return val;
            default:
                throw new UDFArgumentException("ABS only takes SHORT/BYTE/INT/LONG/DOUBLE/FLOAT/STRING/DECIMAL types, got " + this.inputType);
            }
        }
    }
    /**
    * Returns a string representation of the function call for display (e.g. in EXPLAIN output).
    */
	@Override
    public String getDisplayString(String[] children) {
        return this.getStandardDisplayString("abs", children);
    }
}

GenericUDF with multiple input parameters

Defines a function with multiple input parameters and a single output result. The number and data types of the input parameters can be validated.

package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.VoidObjectInspector;
/**
* name: the function name
* value: description of the function's return value
*/
@Description(name = "array",
    value = "_FUNC_(n0, n1...) - Creates an array with the given elements ")
public class GenericUDFArray extends GenericUDF {
  private transient Converter[] converters;
  private transient ArrayList<Object> ret = new ArrayList<Object>();
  /**
  * Initializes the function's output type; arguments are the function's input parameters.
  * Multiple input parameters are converted into a single output list.
  */
  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {

    GenericUDFUtils.ReturnObjectInspectorResolver returnOIResolver = new GenericUDFUtils.ReturnObjectInspectorResolver(true);

    for (int i = 0; i < arguments.length; i++) {
      if (!returnOIResolver.update(arguments[i])) {
        throw new UDFArgumentTypeException(i, "Argument type \""
            + arguments[i].getTypeName()
            + "\" is different from preceding arguments. "
            + "Previous type was \"" + arguments[i - 1].getTypeName() + "\"");
      }
    }

    converters = new Converter[arguments.length];

    ObjectInspector returnOI =
        returnOIResolver.get(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    for (int i = 0; i < arguments.length; i++) {
      converters[i] = ObjectInspectorConverters.getConverter(arguments[i],
          returnOI);
    }

    return ObjectInspectorFactory.getStandardListObjectInspector(returnOI);
  }
  /**
  * Executes the function on the input values in arguments.
  * The returned Object must correspond to the ObjectInspector returned by initialize().
  */
  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    ret.clear();
    for (int i = 0; i < arguments.length; i++) {
      ret.add(converters[i].convert(arguments[i].get()));
    }
    return ret;

  }
  /**
  * Returns a string representation of the function call for display.
  */
  @Override
  public String getDisplayString(String[] children) {
    return getStandardDisplayString("array", children, ",");
  }
}

II. Table-generating functions: UDTF

A UDTF accepts zero or more input parameters and produces multiple rows or multiple columns of output.
When the output has multiple columns, the argument passed to forward() is an array.
When the output has multiple rows, forward() is called inside a loop in the process() method.

When implementing a UDTF, it is better to override initialize(ObjectInspector[] argOIs) rather than initialize(StructObjectInspector argOIs):

public StructObjectInspector initialize(StructObjectInspector argOIs)
      throws UDFArgumentException {
    List<? extends StructField> inputFields = argOIs.getAllStructFieldRefs();
    ObjectInspector[] udtfInputOIs = new ObjectInspector[inputFields.size()];
    for (int i = 0; i < inputFields.size(); i++) {
      udtfInputOIs[i] = inputFields.get(i).getFieldObjectInspector();
    }
    return initialize(udtfInputOIs);
  }

  /**
   * Prefer to implement this method: the method above simply delegates to it,
   * and some versions and other engines (such as Spark) initialize the UDTF through this method.
   */
  @Deprecated
  public StructObjectInspector initialize(ObjectInspector[] argOIs)
      throws UDFArgumentException {
    throw new IllegalStateException("Should not be called directly");
  }
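
Before looking at the built-in explode(), here is a minimal sketch of a custom UDTF that turns a comma-separated string into one row per element. The package, class, and output column names are hypothetical, and the input is assumed to be a single STRING column.

package com.study.udtf; // hypothetical package

import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

public class SplitToRows extends GenericUDTF {

  private transient PrimitiveObjectInspector inputOI;
  private final Object[] forwardObj = new Object[1]; // one output column

  @Override
  public StructObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException {
    if (args.length != 1
        || args[0].getCategory() != ObjectInspector.Category.PRIMITIVE
        || ((PrimitiveObjectInspector) args[0]).getPrimitiveCategory()
            != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
      throw new UDFArgumentException("split_to_rows() takes exactly one string argument");
    }
    inputOI = (PrimitiveObjectInspector) args[0];

    // One output column named "item" of type string.
    List<String> fieldNames = new ArrayList<String>();
    List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
    fieldNames.add("item");
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
  }

  @Override
  public void process(Object[] args) throws HiveException {
    String value = (String) inputOI.getPrimitiveJavaObject(args[0]);
    if (value == null) {
      return;
    }
    // One forward() call per output row.
    for (String part : value.split(",")) {
      forwardObj[0] = part;
      forward(forwardObj);
    }
  }

  @Override
  public void close() throws HiveException {
    // nothing buffered, nothing to flush
  }
}

Packaged into a jar, it can be registered like the functions in section 3 and used directly in SELECT or with LATERAL VIEW.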

Hive's built-in explode() implementation is shown below as a fuller example:


package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.TaskExecutionException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

/**
 * name: the function name
 * value: description of what the function does
 */
@Description(name = "explode",
    value = "_FUNC_(a) - separates the elements of array a into multiple rows,"
      + " or the elements of a map into multiple rows and columns ")
public class GenericUDTFExplode extends GenericUDTF {

  private transient ObjectInspector inputOI = null;
  @Override
  public void close() throws HiveException {
  }
  /**
  * Validates the number and data types of the input parameters and defines the output column names.
  * args: the function's input parameters
  * StructObjectInspector: describes the rows passed to forward(), i.e. the output data.
  */
  @Override
  public StructObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException {
    if (args.length != 1) {
      throw new UDFArgumentException("explode() takes only one argument");
    }

    ArrayList<String> fieldNames = new ArrayList<String>();
    ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

    switch (args[0].getCategory()) {
    case LIST:
      inputOI = args[0];
      // define a single output column
      fieldNames.add("col");
      fieldOIs.add(((ListObjectInspector)inputOI).getListElementObjectInspector());
      break;
    case MAP:
      inputOI = args[0];
      // define two output columns: key and value
      fieldNames.add("key");
      fieldNames.add("value");
      fieldOIs.add(((MapObjectInspector)inputOI).getMapKeyObjectInspector());
      fieldOIs.add(((MapObjectInspector)inputOI).getMapValueObjectInspector());
      break;
    default:
      throw new UDFArgumentException("explode() takes an array or a map as a parameter");
    }

    return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,
        fieldOIs);
  }

  private transient final Object[] forwardListObj = new Object[1];
  private transient final Object[] forwardMapObj = new Object[2];
  /**
  * The execution logic; o holds the function's input parameters.
  */
  @Override
  public void process(Object[] o) throws HiveException {
    switch (inputOI.getCategory()) {
    case LIST:
      ListObjectInspector listOI = (ListObjectInspector)inputOI;
      List<?> list = listOI.getList(o[0]);
      if (list == null) {
        return;
      }
      for (Object r : list) {
        forwardListObj[0] = r;
        forward(forwardListObj);
      }
      break;
    case MAP:
      MapObjectInspector mapOI = (MapObjectInspector)inputOI;
      Map<?,?> map = mapOI.getMap(o[0]);
      if (map == null) {
        return;
      }
      for (Entry<?,?> r : map.entrySet()) {
        forwardMapObj[0] = r.getKey();
        forwardMapObj[1] = r.getValue();
        // emit one output row
        forward(forwardMapObj);
      }
      break;
    default:
      throw new TaskExecutionException("explode() can only operate on an array or a map");
    }
  }

  @Override
  public String toString() {
    return "explode";
  }
}

III. Aggregate functions: UDAF

A UDAF accepts zero or more rows, each with zero or more columns, and returns a single value.

The evaluator runs in one of four modes (see the sketch below):
Mode.PARTIAL1 - from raw data to partially aggregated data (iterate() and terminatePartial() are called)
Mode.PARTIAL2 - from partially aggregated data to partially aggregated data (merge() and terminatePartial() are called)
Mode.FINAL - from partially aggregated data to the fully aggregated result (merge() and terminate() are called)
Mode.COMPLETE - from raw data directly to the fully aggregated result (iterate() and terminate() are called)
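
For contrast with the built-in count() below, here is a minimal sketch of a custom UDAF that sums BIGINT values and shows how the four modes map onto the evaluator methods. The package and class names (com.study.udaf.TestLongSum) are hypothetical, and type conversion and argument validation are kept to a minimum.

package com.study.udaf; // hypothetical package

import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

public class TestLongSum extends AbstractGenericUDAFResolver {

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
    if (parameters.length != 1) {
      throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected");
    }
    return new SumEvaluator();
  }

  public static class SumEvaluator extends GenericUDAFEvaluator {
    // OI of the raw BIGINT input column, used by iterate() in PARTIAL1/COMPLETE
    private LongObjectInspector inputOI;
    // OI of the partial sum produced by terminatePartial(), used by merge() in PARTIAL2/FINAL
    private LongObjectInspector partialOI;
    private final LongWritable result = new LongWritable(0);

    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
      super.init(m, parameters);
      if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
        inputOI = (LongObjectInspector) parameters[0];   // raw rows -> iterate()
      } else {
        partialOI = (LongObjectInspector) parameters[0]; // partial sums -> merge()
      }
      // Partial and final results are both BIGINT here, so one return type covers all modes.
      return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    }

    static class SumAgg extends AbstractAggregationBuffer {
      long sum;
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      SumAgg agg = new SumAgg();
      reset(agg);
      return agg;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      ((SumAgg) agg).sum = 0;
    }

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
      if (parameters == null || parameters[0] == null) {
        return;
      }
      ((SumAgg) agg).sum += inputOI.get(parameters[0]);
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      return terminate(agg);
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
      if (partial != null) {
        ((SumAgg) agg).sum += partialOI.get(partial);
      }
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      result.set(((SumAgg) agg).sum);
      return result;
    }
  }
}

In PARTIAL1/PARTIAL2 the ObjectInspector returned by init() describes the partial result consumed by merge(); in FINAL/COMPLETE it describes the final result. Here both happen to be BIGINT, so a single return type covers every mode.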

Hive's built-in count() implementation:

package org.apache.hadoop.hive.ql.udf.generic;

import java.util.HashSet;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.util.JavaDataModel;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorObject;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

/**
 * name: the function name
 * value: description of the function
 */
@Description(name = "count",
    value = "_FUNC_(*) - Returns the total number of retrieved rows, including "
          +        "rows containing NULL values.\n"

          + "_FUNC_(expr) - Returns the number of rows for which the supplied "
          +        "expression is non-NULL.\n"

          + "_FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for "
          +        "which the supplied expression(s) are unique and non-NULL.")
public class GenericUDAFCount implements GenericUDAFResolver2 {
  /**
  * Returns the evaluator for this function; this overload is kept for backward compatibility.
  */
  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
      throws SemanticException {
    // This method implementation is preserved for backward compatibility.
    return new GenericUDAFCountEvaluator();
  }
  /**
  * Returns the evaluator for this function after validating its input parameters.
  */
  @Override
  public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo paramInfo)
  throws SemanticException {
	// type information for all of the parameters
    TypeInfo[] parameters = paramInfo.getParameters();

    if (parameters.length == 0) {
      if (!paramInfo.isAllColumns()) {
        throw new UDFArgumentException("Argument expected");
      }
      assert !paramInfo.isDistinct() : "DISTINCT not supported with *";
    } else {
      if (parameters.length > 1 && !paramInfo.isDistinct()) {
        throw new UDFArgumentException("DISTINCT keyword must be specified");
      }
      assert !paramInfo.isAllColumns() : "* not supported in expression list";
    }

    GenericUDAFCountEvaluator countEvaluator = new GenericUDAFCountEvaluator();
    countEvaluator.setWindowing(paramInfo.isWindowing());
    countEvaluator.setCountAllColumns(paramInfo.isAllColumns());
    countEvaluator.setCountDistinct(paramInfo.isDistinct());

    return countEvaluator;
  }

  /**
   * The evaluator that does the actual work.
   */
  public static class GenericUDAFCountEvaluator extends GenericUDAFEvaluator {
    private boolean isWindowing = false;
    private boolean countAllColumns = false;
    private boolean countDistinct = false;
    // data type of the partial (intermediate) count produced during aggregation
    private LongObjectInspector partialCountAggOI;
    // object inspectors for the raw input columns and their standard Java copies
    private ObjectInspector[] inputOI, outputOI;
    // the final result of the function
    private LongWritable result;

    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
    throws HiveException {
      super.init(m, parameters);
      // PARTIAL2 or FINAL: merge() will be called on partial counts
      if (mode == Mode.PARTIAL2 || mode == Mode.FINAL) {
        partialCountAggOI = (LongObjectInspector)parameters[0];
      } else {
      // PARTIAL1 or COMPLETE: iterate() will be called on raw input rows
        inputOI = parameters;
        outputOI = ObjectInspectorUtils.getStandardObjectInspector(inputOI,
            ObjectInspectorCopyOption.JAVA);
      }
      result = new LongWritable(0);

      return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    }

    public void setWindowing(boolean isWindowing) {
      this.isWindowing = isWindowing;
    }

    private void setCountAllColumns(boolean countAllCols) {
      countAllColumns = countAllCols;
    }

    private void setCountDistinct(boolean countDistinct) {
      this.countDistinct = countDistinct;
    }

    private boolean isWindowingDistinct() {
      return isWindowing && countDistinct;
    }

    /** The aggregation buffer that holds the intermediate count. */
    @AggregationType(estimable = true)
    static class CountAgg extends AbstractAggregationBuffer {
      HashSet<ObjectInspectorObject> uniqueObjects; // Unique rows
      long value;
      @Override
      public int estimate() { return JavaDataModel.PRIMITIVES2; }
    }
	/**
	* Creates a new aggregation buffer.
	*/
    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      CountAgg buffer = new CountAgg();
      reset(buffer);
      return buffer;
    }
	/**
	* Resets the aggregation buffer.
	*/
    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      ((CountAgg) agg).value = 0;
      ((CountAgg) agg).uniqueObjects = new HashSet<ObjectInspectorObject>();
    }
	/**
	* Folds raw input rows into the aggregation buffer.
	* agg: the aggregation buffer
	* parameters: the raw input values
	*/
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
      throws HiveException {
      // parameters == null means the input table/split is empty
      if (parameters == null) {
        return;
      }
      if (countAllColumns) {
        assert parameters.length == 0;
        ((CountAgg) agg).value++;
      } else {
        boolean countThisRow = true;
        for (Object nextParam : parameters) {
          if (nextParam == null) {
            countThisRow = false;
            break;
          }
        }

        // Skip the counting if the values are the same for windowing COUNT(DISTINCT) case
        if (countThisRow && isWindowingDistinct()) {
          HashSet<ObjectInspectorObject> uniqueObjs = ((CountAgg) agg).uniqueObjects;
          ObjectInspectorObject obj = new ObjectInspectorObject(
              ObjectInspectorUtils.copyToStandardObject(parameters, inputOI, ObjectInspectorCopyOption.JAVA),
              outputOI);
          if (!uniqueObjs.contains(obj)) {
            uniqueObjs.add(obj);
          } else {
            countThisRow = false;
          }
        }

        if (countThisRow) {
          ((CountAgg) agg).value++;
        }
      }
    }
    /**
	* Returns the partial aggregation result.
	*/
    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      if (isWindowingDistinct()) {
        throw new HiveException("Distinct windowing UDAF doesn't support merge and terminatePartial");
      } else {
        return terminate(agg);
      }
    }
	/**
	* Merges a partial result into the aggregation buffer.
	* agg: the aggregation buffer
	* partial: a partial aggregation result
	*/
    @Override
    public void merge(AggregationBuffer agg, Object partial)
      throws HiveException {
      if (partial != null) {
        CountAgg countAgg = (CountAgg) agg;

        if (isWindowingDistinct()) {
          throw new HiveException("Distinct windowing UDAF doesn't support merge and terminatePartial");
        } else {
          long p = partialCountAggOI.get(partial);
          countAgg.value += p;
        }
      }
    }
	/**
	* Returns the final aggregation result.
	*/
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      result.set(((CountAgg) agg).value);
      return result;
    }
	
  }
}


Summary

When writing a custom aggregate function, pay attention to the Mode value, especially in init(): you need to know which data types the evaluator receives and returns at each stage of execution.
