Preface
Custom functions extend the set of functions available in Hive queries; a custom function can be created either as a temporary function or as a permanent function.
1. Importing the libraries
All of the sections below depend on the following jar packages:
<dependency>
<groupId>org.apache.hive</groupId>
<artifactId>hive-exec</artifactId>
<version>2.3.9</version>
</dependency>
<dependency>
<groupId>org.apache.hive</groupId>
<artifactId>hive-jdbc</artifactId>
<version>2.3.9</version>
</dependency>
2. Querying functions via commands
List all functions:
SHOW FUNCTIONS;
+---------------------------+
| tab_name |
+---------------------------+
| avg |
| uuid |
| sum |
+---------------------------+
Query functions by fuzzy match:
SHOW FUNCTIONS LIKE '%avg%';
+---------------------------+
| tab_name |
+---------------------------+
| avg |
+---------------------------+
Show a function's brief description:
DESCRIBE FUNCTION avg;
+------------------------------------------------+
| tab_name |
+------------------------------------------------+
| avg(x) - Returns the mean of a set of numbers |
+------------------------------------------------+
Show a function's detailed description:
DESCRIBE FUNCTION EXTENDED avg;
+------------------------------------------------+
| tab_name |
+------------------------------------------------+
| avg(x) - Returns the mean of a set of numbers |
| Function class:org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage |
| Function Type: BUILTIN |
+------------------------------------------------+
3. Adding a custom function to the Hive function library
Register a custom function as a permanent function:
-- switch to the target database
use test;
-- drop the function if it already exists
drop function if exists fn_test_avg;
-- create the function: specify the function name, the class name, and the path of the jar on HDFS
create function fn_test_avg as 'com.study.udaf.TestAVG' using jar 'hdfs://nameservicetentant/user/study/jar/study.jar';
Register a custom function as a temporary function:
-- switch to the target database
use test;
-- drop the function if it already exists
drop temporary function if exists fn_test_avg;
-- create the function: specify the function name, the class name, and the path of the jar on HDFS
create temporary function fn_test_avg as 'com.study.udaf.TestAVG' using jar 'hdfs://nameservicetentant/user/study/jar/study.jar';
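Once registered, the function can be invoked like any built-in function. As a quick smoke test, here is a minimal sketch using the hive-jdbc driver declared in section 1; the HiveServer2 URL, the credentials, and the t_score table are placeholder assumptions, not part of the original setup.

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;

public class FnTestAvgSmokeTest {
    public static void main(String[] args) throws Exception {
        // Explicit driver load; recent hive-jdbc versions also auto-register via JDBC 4.
        Class.forName("org.apache.hive.jdbc.HiveDriver");
        // HiveServer2 URL, user, and table are assumptions -- adjust to your cluster.
        String url = "jdbc:hive2://localhost:10000/test";
        try (Connection conn = DriverManager.getConnection(url, "study", "");
             Statement stmt = conn.createStatement();
             // Call the function registered above.
             ResultSet rs = stmt.executeQuery("SELECT fn_test_avg(score) FROM t_score")) {
            while (rs.next()) {
                System.out.println("fn_test_avg = " + rs.getDouble(1));
            }
        }
    }
}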
I. Standard functions: UDF
A standard function takes one or more columns of a single row as parameters and returns a single value.
Built-in standard functions include abs() (absolute value) and reverse() (string reversal).
Extending UDF
This approach works when the function takes a single input parameter and the output data type matches the input data type:
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by Fernflower decompiler)
//
package org.apache.hadoop.hive.ql.udf;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils;
import org.apache.hadoop.io.Text;
/**
 * name: the function name
 * value: description of the return value
 * extended: usage example
 */
@Description(
name = "reverse",
value = "_FUNC_(str) - reverse str",
extended = "Example:\n > SELECT _FUNC_('Facebook') FROM src LIMIT 1;\n 'koobecaF'"
)
public class UDFReverse extends UDF {
private final Text result = new Text();
public UDFReverse() {
}
private void reverse(byte[] arr, int first, int last) {
for(int i = 0; i < (last - first + 1) / 2; ++i) {
byte temp = arr[last - i];
arr[last - i] = arr[first + i];
arr[first + i] = temp;
}
}
/**
 * The parent-class constructor installs a DefaultUDFMethodResolver,
 * whose getEvalMethod() looks up this evaluate method by reflection:
 *
 *   public Method getEvalMethod(List<TypeInfo> argClasses) throws UDFArgumentException {
 *       return FunctionRegistry.getMethodInternal(this.udfClass, "evaluate", false, argClasses);
 *   }
 */
public Text evaluate(Text s) {
if (s == null) {
return null;
} else {
this.result.set(s);
byte[] data = this.result.getBytes();
int prev = 0;
for(int i = 1; i < this.result.getLength(); ++i) {
if (GenericUDFUtils.isUtfStartByte(data[i])) {
this.reverse(data, prev, i - 1);
prev = i;
}
}
this.reverse(data, prev, this.result.getLength() - 1);
this.reverse(data, 0, this.result.getLength() - 1);
return this.result;
}
}
}
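Because DefaultUDFMethodResolver resolves evaluate by name and argument types, writing your own UDF only requires extending UDF and supplying a matching evaluate method. A minimal sketch, with an illustrative package and function name (not part of Hive):

package com.study.udf; // illustrative package name

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.Text;

@Description(
    name = "to_upper",
    value = "_FUNC_(str) - returns str converted to upper case",
    extended = "Example:\n > SELECT _FUNC_('abc');\n 'ABC'"
)
public class UDFToUpper extends UDF {
    // Reused across rows to avoid allocating a new Text per call.
    private final Text result = new Text();

    // Resolved by DefaultUDFMethodResolver at query compile time.
    public Text evaluate(Text s) {
        if (s == null) {
            return null;
        }
        result.set(s.toString().toUpperCase());
        return result;
    }
}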
GenericUDF with a single input parameter
Defines a function with a single input parameter and a single output value; the format and data type of the input argument can be validated.
package org.apache.hadoop.hive.ql.udf.generic;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncAbsDecimalToDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncAbsDoubleToDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncAbsLongToLong;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
/**
 * name: the function name
 * value: description of the return value
 * extended: usage examples
 */
@Description(
name = "abs",
value = "_FUNC_(x) - returns the absolute value of x",
extended = "Example:\n > SELECT _FUNC_(0) FROM src LIMIT 1;\n 0\n > SELECT _FUNC_(-5) FROM src LIMIT 1;\n 5"
)
@VectorizedExpressions({FuncAbsLongToLong.class, FuncAbsDoubleToDouble.class, FuncAbsDecimalToDecimal.class})
public class GenericUDFAbs extends GenericUDF {
private transient PrimitiveCategory inputType;
private final DoubleWritable resultDouble = new DoubleWritable();
private final LongWritable resultLong = new LongWritable();
private final IntWritable resultInt = new IntWritable();
private final HiveDecimalWritable resultDecimal = new HiveDecimalWritable();
private transient PrimitiveObjectInspector argumentOI;
private transient Converter inputConverter;
public GenericUDFAbs() {
}
/**
 * Validates the input parameters and decides the output type.
 * arguments holds the ObjectInspectors of the function's inputs.
 */
@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
if (arguments.length != 1) {
throw new UDFArgumentLengthException("ABS() requires 1 argument, got " + arguments.length);
} else if (arguments[0].getCategory() != Category.PRIMITIVE) {
throw new UDFArgumentException("ABS only takes primitive types, got " + arguments[0].getTypeName());
} else {
this.argumentOI = (PrimitiveObjectInspector)arguments[0];
this.inputType = this.argumentOI.getPrimitiveCategory();
ObjectInspector outputOI = null;
switch(this.inputType) {
case SHORT:
case BYTE:
case INT:
this.inputConverter = ObjectInspectorConverters.getConverter(arguments[0], PrimitiveObjectInspectorFactory.writableIntObjectInspector);
outputOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
break;
case LONG:
this.inputConverter = ObjectInspectorConverters.getConverter(arguments[0], PrimitiveObjectInspectorFactory.writableLongObjectInspector);
outputOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
break;
case FLOAT:
case STRING:
case DOUBLE:
this.inputConverter = ObjectInspectorConverters.getConverter(arguments[0], PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
break;
case DECIMAL:
outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(((PrimitiveObjectInspector)arguments[0]).getTypeInfo());
this.inputConverter = ObjectInspectorConverters.getConverter(arguments[0], (ObjectInspector)outputOI);
break;
default:
throw new UDFArgumentException("ABS only takes SHORT/BYTE/INT/LONG/DOUBLE/FLOAT/STRING/DECIMAL types, got " + this.inputType);
}
return (ObjectInspector)outputOI;
}
}
/**
 * The execution method; arguments holds the function's input values.
 * Note that the returned object must match the ObjectInspector
 * declared in initialize().
 */
@Override
public Object evaluate(DeferredObject[] arguments) throws HiveException {
Object valObject = arguments[0].get();
if (valObject == null) {
return null;
} else {
switch(this.inputType) {
case SHORT:
case BYTE:
case INT:
valObject = this.inputConverter.convert(valObject);
this.resultInt.set(Math.abs(((IntWritable)valObject).get()));
return this.resultInt;
case LONG:
valObject = this.inputConverter.convert(valObject);
this.resultLong.set(Math.abs(((LongWritable)valObject).get()));
return this.resultLong;
case FLOAT:
case STRING:
case DOUBLE:
valObject = this.inputConverter.convert(valObject);
if (valObject == null) {
return null;
}
this.resultDouble.set(Math.abs(((DoubleWritable)valObject).get()));
return this.resultDouble;
case DECIMAL:
HiveDecimalObjectInspector decimalOI = (HiveDecimalObjectInspector)this.argumentOI;
HiveDecimalWritable val = decimalOI.getPrimitiveWritableObject(valObject);
if (val != null) {
this.resultDecimal.set(val);
this.resultDecimal.mutateAbs();
val = this.resultDecimal;
}
return val;
default:
throw new UDFArgumentException("ABS only takes SHORT/BYTE/INT/LONG/DOUBLE/FLOAT/STRING/DECIMAL types, got " + this.inputType);
}
}
}
/**
 * Display string for this function call (shown in EXPLAIN output).
 */
@Override
public String getDisplayString(String[] children) {
return this.getStandardDisplayString("abs", children);
}
}
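Most custom GenericUDFs need far less type dispatch than abs(). A minimal sketch of the same initialize/evaluate/getDisplayString pattern, assuming a hypothetical str_len function that returns the length in bytes of a string's UTF-8 encoding (all names are illustrative):

package com.study.udf; // illustrative

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

@Description(name = "str_len", value = "_FUNC_(str) - returns the length of str in bytes")
public class GenericUDFStrLen extends GenericUDF {
    private transient Converter toText;
    private final IntWritable result = new IntWritable();

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        // Validate the argument count and category.
        if (arguments.length != 1) {
            throw new UDFArgumentLengthException("str_len() requires 1 argument, got " + arguments.length);
        }
        if (arguments[0].getCategory() != Category.PRIMITIVE) {
            throw new UDFArgumentException("str_len only takes primitive types, got " + arguments[0].getTypeName());
        }
        // Convert whatever primitive arrives into a writable string.
        toText = ObjectInspectorConverters.getConverter(
            arguments[0], PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        // The type promised here must match what evaluate() returns.
        return PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        Object val = arguments[0].get();
        if (val == null) {
            return null;
        }
        // Text.getLength() is the byte length of the UTF-8 encoding.
        result.set(((Text) toText.convert(val)).getLength());
        return result;
    }

    @Override
    public String getDisplayString(String[] children) {
        return getStandardDisplayString("str_len", children);
    }
}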
GenericUDF with multiple input parameters
Defines a function with multiple input parameters and a single output value; the number and data types of the input arguments can be validated.
package org.apache.hadoop.hive.ql.udf.generic;
import java.util.ArrayList;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.VoidObjectInspector;
/**
 * name: the function name
 * value: description of the return value
 */
@Description(name = "array",
value = "_FUNC_(n0, n1...) - Creates an array with the given elements ")
public class GenericUDFArray extends GenericUDF {
private transient Converter[] converters;
private transient ArrayList<Object> ret = new ArrayList<Object>();
/**
 * Initializes the output type; arguments holds the function's inputs.
 * Multiple input parameters are converted into a single list output.
 */
@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
GenericUDFUtils.ReturnObjectInspectorResolver returnOIResolver = new GenericUDFUtils.ReturnObjectInspectorResolver(true);
for (int i = 0; i < arguments.length; i++) {
if (!returnOIResolver.update(arguments[i])) {
throw new UDFArgumentTypeException(i, "Argument type \""
+ arguments[i].getTypeName()
+ "\" is different from preceding arguments. "
+ "Previous type was \"" + arguments[i - 1].getTypeName() + "\"");
}
}
converters = new Converter[arguments.length];
ObjectInspector returnOI =
returnOIResolver.get(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
for (int i = 0; i < arguments.length; i++) {
converters[i] = ObjectInspectorConverters.getConverter(arguments[i],
returnOI);
}
return ObjectInspectorFactory.getStandardListObjectInspector(returnOI);
}
/**
 * The execution method; arguments holds the input values.
 * The returned Object must correspond to the ObjectInspector
 * returned by initialize().
 */
@Override
public Object evaluate(DeferredObject[] arguments) throws HiveException {
ret.clear();
for (int i = 0; i < arguments.length; i++) {
ret.add(converters[i].convert(arguments[i].get()));
}
return ret;
}
/**
 * Display string for this function call.
 */
@Override
public String getDisplayString(String[] children) {
return getStandardDisplayString("array", children, ",");
}
}
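A design note on the code above: ReturnObjectInspectorResolver(true) is updated with each argument to find a common type for all of them, allowing implicit conversion, and every argument then gets a Converter to that common type. Arguments of compatible but different types (say, int and double) are therefore unified rather than rejected; only truly incompatible types trigger the UDFArgumentTypeException.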
II. Table-generating functions: UDTF
A UDTF accepts zero or more input parameters and produces multiple rows or multiple columns of output.
When the output has multiple columns, the array passed to forward() holds one element per column.
When the output has multiple rows, process() calls forward() once per row inside a loop.
When implementing a UDTF, prefer to override initialize(ObjectInspector[] argOIs) rather than initialize(StructObjectInspector argOIs):
public StructObjectInspector initialize(StructObjectInspector argOIs)
throws UDFArgumentException {
List<? extends StructField> inputFields = argOIs.getAllStructFieldRefs();
ObjectInspector[] udtfInputOIs = new ObjectInspector[inputFields.size()];
for (int i = 0; i < inputFields.size(); i++) {
udtfInputOIs[i] = inputFields.get(i).getFieldObjectInspector();
}
return initialize(udtfInputOIs);
}
/**
 * Prefer to implement this overload: the method above automatically
 * delegates to it, and some versions and other engines (such as Spark)
 * initialize the UDTF through it.
 */
@Deprecated
public StructObjectInspector initialize(ObjectInspector[] argOIs)
throws UDFArgumentException {
throw new IllegalStateException("Should not be called directly");
}
Example code (Hive's built-in explode):
package org.apache.hadoop.hive.ql.udf.generic;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.TaskExecutionException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
/**
 * name: the function name
 * value: description of what the function does
 */
@Description(name = "explode",
value = "_FUNC_(a) - separates the elements of array a into multiple rows,"
+ " or the elements of a map into multiple rows and columns ")
public class GenericUDTFExplode extends GenericUDTF {
private transient ObjectInspector inputOI = null;
@Override
public void close() throws HiveException {
}
/**
 * Validates the number of input arguments and defines the output
 * column names and data types.
 * args: the function's input parameters.
 * The returned StructObjectInspector describes the rows emitted via forward().
 */
@Override
public StructObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException {
if (args.length != 1) {
throw new UDFArgumentException("explode() takes only one argument");
}
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
switch (args[0].getCategory()) {
case LIST:
inputOI = args[0];
// define a single output column
fieldNames.add("col");
fieldOIs.add(((ListObjectInspector)inputOI).getListElementObjectInspector());
break;
case MAP:
inputOI = args[0];
// define two output columns (key, value)
fieldNames.add("key");
fieldNames.add("value");
fieldOIs.add(((MapObjectInspector)inputOI).getMapKeyObjectInspector());
fieldOIs.add(((MapObjectInspector)inputOI).getMapValueObjectInspector());
break;
default:
throw new UDFArgumentException("explode() takes an array or a map as a parameter");
}
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,
fieldOIs);
}
private transient final Object[] forwardListObj = new Object[1];
private transient final Object[] forwardMapObj = new Object[2];
/**
 * The execution logic: o holds the function's input values.
 */
@Override
public void process(Object[] o) throws HiveException {
switch (inputOI.getCategory()) {
case LIST:
ListObjectInspector listOI = (ListObjectInspector)inputOI;
List<?> list = listOI.getList(o[0]);
if (list == null) {
return;
}
for (Object r : list) {
forwardListObj[0] = r;
forward(forwardListObj);
}
break;
case MAP:
MapObjectInspector mapOI = (MapObjectInspector)inputOI;
Map<?,?> map = mapOI.getMap(o[0]);
if (map == null) {
return;
}
for (Entry<?,?> r : map.entrySet()) {
forwardMapObj[0] = r.getKey();
forwardMapObj[1] = r.getValue();
// emit one output row
forward(forwardMapObj);
}
break;
default:
throw new TaskExecutionException("explode() can only operate on an array or a map");
}
}
@Override
public String toString() {
return "explode";
}
}
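Writing your own UDTF follows the same three-method shape: declare the output columns in initialize(), call forward() once per output row in process(), and flush any buffered state in close(). Below is a minimal sketch that splits a comma-delimited string into rows, overriding the deprecated array-based initialize() as recommended above; all names are illustrative.

package com.study.udtf; // illustrative

import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;

public class GenericUDTFSplitRows extends GenericUDTF {
    private transient StringObjectInspector stringOI;
    private final Object[] forwardObj = new Object[1];

    @Override
    public StructObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException {
        if (args.length != 1 || !(args[0] instanceof StringObjectInspector)) {
            throw new UDFArgumentException("split_rows() takes a single string argument");
        }
        stringOI = (StringObjectInspector) args[0];
        // One output column named "item" of type string.
        List<String> fieldNames = new ArrayList<>();
        List<ObjectInspector> fieldOIs = new ArrayList<>();
        fieldNames.add("item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    public void process(Object[] args) throws HiveException {
        String input = stringOI.getPrimitiveJavaObject(args[0]);
        if (input == null) {
            return; // a NULL input produces no rows, like explode()
        }
        for (String part : input.split(",")) {
            forwardObj[0] = part;
            forward(forwardObj); // one call per output row
        }
    }

    @Override
    public void close() throws HiveException {
        // nothing buffered, nothing to flush
    }
}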
III. Aggregate functions: UDAF
A UDAF accepts zero or more rows, each with zero or more columns, and returns a single value.
Mode.PARTIAL1: from raw data to a partial aggregation
Mode.PARTIAL2: from one partial aggregation to another partial aggregation
Mode.FINAL: from a partial aggregation to the full aggregation
Mode.COMPLETE: from raw data directly to the full aggregation
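In terms of the evaluator methods shown below: PARTIAL1 runs iterate() followed by terminatePartial(); PARTIAL2 runs merge() followed by terminatePartial(); FINAL runs merge() followed by terminate(); COMPLETE runs iterate() followed by terminate().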
Implementation code (Hive's built-in count):
package org.apache.hadoop.hive.ql.udf.generic;
import java.util.HashSet;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.util.JavaDataModel;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorObject;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
/**
 * name: the function name
 * value: description of the function
 */
@Description(name = "count",
value = "_FUNC_(*) - Returns the total number of retrieved rows, including "
+ "rows containing NULL values.\n"
+ "_FUNC_(expr) - Returns the number of rows for which the supplied "
+ "expression is non-NULL.\n"
+ "_FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for "
+ "which the supplied expression(s) are unique and non-NULL.")
public class GenericUDAFCount implements GenericUDAFResolver2 {
/**
 * Returns the evaluator for this function; this overload is kept
 * for backward compatibility.
 */
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
throws SemanticException {
// This method implementation is preserved for backward compatibility.
return new GenericUDAFCountEvaluator();
}
/**
 * Returns the evaluator and validates how the function is invoked
 * (argument count, DISTINCT, and *).
 */
@Override
public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo paramInfo)
throws SemanticException {
// type information for every argument
TypeInfo[] parameters = paramInfo.getParameters();
if (parameters.length == 0) {
if (!paramInfo.isAllColumns()) {
throw new UDFArgumentException("Argument expected");
}
assert !paramInfo.isDistinct() : "DISTINCT not supported with *";
} else {
if (parameters.length > 1 && !paramInfo.isDistinct()) {
throw new UDFArgumentException("DISTINCT keyword must be specified");
}
assert !paramInfo.isAllColumns() : "* not supported in expression list";
}
GenericUDAFCountEvaluator countEvaluator = new GenericUDAFCountEvaluator();
countEvaluator.setWindowing(paramInfo.isWindowing());
countEvaluator.setCountAllColumns(paramInfo.isAllColumns());
countEvaluator.setCountDistinct(paramInfo.isDistinct());
return countEvaluator;
}
/**
 * The evaluator that performs the actual aggregation.
 */
public static class GenericUDAFCountEvaluator extends GenericUDAFEvaluator {
private boolean isWindowing = false;
private boolean countAllColumns = false;
private boolean countDistinct = false;
// inspector for the partial count produced during computation
private LongObjectInspector partialCountAggOI;
// inspectors for the raw input columns and their standard Java copies
private ObjectInspector[] inputOI, outputOI;
// the function's final output value
private LongWritable result;
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
throws HiveException {
super.init(m, parameters);
// PARTIAL2 / FINAL: merge() consumes a partial count (a single long)
if (mode == Mode.PARTIAL2 || mode == Mode.FINAL) {
partialCountAggOI = (LongObjectInspector)parameters[0];
} else {
// PARTIAL1 / COMPLETE: iterate() consumes the raw input columns
inputOI = parameters;
outputOI = ObjectInspectorUtils.getStandardObjectInspector(inputOI,
ObjectInspectorCopyOption.JAVA);
}
result = new LongWritable(0);
return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
}
public void setWindowing(boolean isWindowing) {
this.isWindowing = isWindowing;
}
private void setCountAllColumns(boolean countAllCols) {
countAllColumns = countAllCols;
}
private void setCountDistinct(boolean countDistinct) {
this.countDistinct = countDistinct;
}
private boolean isWindowingDistinct() {
return isWindowing && countDistinct;
}
/** The aggregation buffer holding the intermediate state. */
@AggregationType(estimable = true)
static class CountAgg extends AbstractAggregationBuffer {
HashSet<ObjectInspectorObject> uniqueObjects; // Unique rows
long value;
@Override
public int estimate() { return JavaDataModel.PRIMITIVES2; }
}
/**
 * Creates a new aggregation buffer.
 */
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
CountAgg buffer = new CountAgg();
reset(buffer);
return buffer;
}
/**
 * Resets the aggregation buffer.
 */
@Override
public void reset(AggregationBuffer agg) throws HiveException {
((CountAgg) agg).value = 0;
((CountAgg) agg).uniqueObjects = new HashSet<ObjectInspectorObject>();
}
/**
 * Folds raw input rows into the aggregation buffer.
 * agg: the aggregation buffer
 * parameters: the raw input row
 */
@Override
public void iterate(AggregationBuffer agg, Object[] parameters)
throws HiveException {
// parameters == null means the input table/split is empty
if (parameters == null) {
return;
}
if (countAllColumns) {
assert parameters.length == 0;
((CountAgg) agg).value++;
} else {
boolean countThisRow = true;
for (Object nextParam : parameters) {
if (nextParam == null) {
countThisRow = false;
break;
}
}
// Skip the counting if the values are the same for windowing COUNT(DISTINCT) case
if (countThisRow && isWindowingDistinct()) {
HashSet<ObjectInspectorObject> uniqueObjs = ((CountAgg) agg).uniqueObjects;
ObjectInspectorObject obj = new ObjectInspectorObject(
ObjectInspectorUtils.copyToStandardObject(parameters, inputOI, ObjectInspectorCopyOption.JAVA),
outputOI);
if (!uniqueObjs.contains(obj)) {
uniqueObjs.add(obj);
} else {
countThisRow = false;
}
}
if (countThisRow) {
((CountAgg) agg).value++;
}
}
}
/**
 * Returns the partial aggregation result.
 */
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
if (isWindowingDistinct()) {
throw new HiveException("Distinct windowing UDAF doesn't support merge and terminatePartial");
} else {
return terminate(agg);
}
}
/**
 * Merges a partial aggregation into the buffer.
 * agg: the aggregation buffer
 * partial: a partial result produced by terminatePartial()
 */
@Override
public void merge(AggregationBuffer agg, Object partial)
throws HiveException {
if (partial != null) {
CountAgg countAgg = (CountAgg) agg;
if (isWindowingDistinct()) {
throw new HiveException("Distinct windowing UDAF doesn't support merge and terminatePartial");
} else {
long p = partialCountAggOI.get(partial);
countAgg.value += p;
}
}
}
/**
 * Returns the final aggregation result.
 */
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
result.set(((CountAgg) agg).value);
return result;
}
}
}
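For your own UDAF (for example the com.study.udaf.TestAVG class registered in section 3), the same skeleton applies: a resolver that validates the arguments, an evaluator, and an aggregation buffer. Below is a minimal sketch of a long-sum aggregate, under the simplifying assumption that the raw input, the partial result, and the final result are all longs; a real average would additionally carry the row count through the partial results. All names are illustrative.

package com.study.udaf; // illustrative, matching the jar path used in section 3

import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

public class TestSum extends AbstractGenericUDAFResolver {
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected");
        }
        return new SumLongEvaluator();
    }

    public static class SumLongEvaluator extends GenericUDAFEvaluator {
        private PrimitiveObjectInspector inputOI; // raw column or partial sum
        private final LongWritable result = new LongWritable(0);

        static class SumAgg extends AbstractAggregationBuffer {
            long sum;
        }

        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            // In every mode the single parameter (raw value in PARTIAL1/COMPLETE,
            // partial sum in PARTIAL2/FINAL) is a primitive; both the partial and
            // the final result are longs.
            inputOI = (PrimitiveObjectInspector) parameters[0];
            return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            SumAgg buf = new SumAgg();
            reset(buf);
            return buf;
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            ((SumAgg) agg).sum = 0;
        }

        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            if (parameters == null || parameters[0] == null) {
                return; // NULL rows do not contribute to the sum
            }
            ((SumAgg) agg).sum += PrimitiveObjectInspectorUtils.getLong(parameters[0], inputOI);
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            return terminate(agg); // the partial result has the same shape as the final one
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial != null) {
                ((SumAgg) agg).sum += PrimitiveObjectInspectorUtils.getLong(partial, inputOI);
            }
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            result.set(((SumAgg) agg).sum);
            return result;
        }
    }
}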
Summary
When writing a custom aggregate function, pay close attention to the Mode, especially inside init(): you must know which data types the parameters carry in each execution phase (raw input columns versus partial aggregation results).