什么是udf?
UDF(User-Defined Functions)即是用户自定义的hive函数。hive自带的函数并不能完全满足业务的需求,这时就需要我们自定义函数了。
udf的分类
udf:输入一条数据输出一条数据,相当于substr()函数;
udaf:输入多条输出一条,相当于聚合函数,count();
udtf:输入一条,输出多条,如lateral view 与 explode
UDF(单输入单输出)
继承自org.apache.hadoop.hive.ql.exec.UDF
- 只能对基础数据类型做处理:byte(位)、short(短整数)、int(整数)、long(长整数)、float(单精度)、double(双精度)、char(字符)和boolean(布尔值)
- 只要实现evaluate()方法就可以
示例:
import org.apache.hadoop.hive.ql.exec.UDF;
public class StrDemo extends UDF {
public String evaluate(String line) {
StringBuilder sb = new StringBuilder();
sb.append(line).append("_test") ;
return sb.toString();
}
}
继承自org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
- 可以处理、返回复杂的数据类型:map,struct,array等
- 实现:必须重写3个重要方法
ObjectInspector initialize(ObjectInspector[] arguments)
//初始化操作,只会在evaluate方法前被调用一次,检验并定义返回值的类型
Object evaluate(DeferredObject[] arguments)
//进行业务逻辑计算,处理具体数据
String getDisplayString(String[] children)
//显示函数的帮助信息
示例:
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.BooleanWritable;
// 注释 : 数组是否包含某个值
@Description(name = "array_contains", value = "_FUNC_(array, value) - Returns TRUE if the array contains value.", extended = "Example:\n > SELECT _FUNC_(array(1, 2, 3), 2) FROM src LIMIT 1;\n true")
public class ArrayContains extends GenericUDF {
// 参数 2 类型
private transient ObjectInspector valueOI;
// 参数 1 类型 :集合类型
private transient ListObjectInspector arrayOI;
// 集合中的元素类型
private transient ObjectInspector arrayElementOI;
// 返回值类型
private BooleanWritable result;
@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
// 校验参数个数
if (arguments.length != 2) {
throw new UDFArgumentException("we need 2 args.");
}
// 校验参数类型 ( 第一个参数,需要是一个集合 )
if (!(arguments[0].getCategory().equals(ObjectInspector.Category.LIST))) {
throw new UDFArgumentTypeException(0, "arg[0] need to be list type");
}
// 参数 1 类型 : 集合乐行
this.arrayOI = ((ListObjectInspector) arguments[0]);
// 集合中的元素的类型
this.arrayElementOI = this.arrayOI.getListElementObjectInspector();
// 参数 2 类型
this.valueOI = arguments[1];
// 校验 : 参数 2 必须和集合中的元素的类型是同一类型
if (!(ObjectInspectorUtils.compareTypes(this.arrayElementOI, this.valueOI))) {
throw new UDFArgumentTypeException(1,"list arg type must be same type with arg1 type");
}
// 校验 : 参数 2 必须是 hive 支持的类型
if (!(ObjectInspectorUtils.compareSupported(this.valueOI))) {
throw new UDFArgumentException("we don't support the type");
}
// 返回值类型:boolean
this.result = new BooleanWritable(false);
// 确定返回值类型:boolean
return PrimitiveObjectInspectorFactory.writableBooleanObjectInspector;
}
@Override // 逻辑代码
public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException {
// 默认返回 false
this.result.set(false);
// 取出参数,默认 Object 类型
Object array = arguments[0].get();
Object value = arguments[1].get();
// 集合长度
int arrayLength = this.arrayOI.getListLength(array);
// 如果集合中没有元素,或者被比较的元素是 null ,直接返回 false ;
if ((value == null) || (arrayLength <= 0)) {
return this.result;
}
// 比较,逻辑处理
for (int i = 0; i < arrayLength; ++i) {
Object listElement = this.arrayOI.getListElement(array, i);
// 方法解释 : ObjectInspectorUtils.compare(value, this.valueOI, listElement, this.arrayElementOI)
// 四个参数 : 第一个值,类型,第二个值,类型
// 逻辑解释 : 比较不一样,就下一组继续,比较一样,就这个了,赋值为 true ,方法结束
if ((listElement == null)|| (ObjectInspectorUtils.compare(value, this.valueOI, listElement, this.arrayElementOI) != 0)){
continue;
}else{
this.result.set(true);
break;
}
}
return this.result;
}
public String getDisplayString(String[] children) {
assert (children.length == 2);
return "array_contains(" + children[0] + ", " + children[1] + ")";
}
UDAF(多输入一输出)
一般涉及的两个抽象类:
org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver
用于在hive中注册UDAF,里面会实例化FieldLengthUDAFEvaluator,该类需继承AbstractGenericUDAFResolver
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
用于实现四个阶段中会被调用的方法,该类需继承GenericUDAFEvaluator;
为了更好理解上述抽象类的API,要记住hive只是mapreduce函数,只不过hive已经帮助我们写好并隐藏mapreduce,向上提供简洁的sql函数,所以我们要结合Mapper、Combiner与Reducer来帮助我们理解这个函数。要记住在hadoop集群中有若干台机器,在不同的机器上Mapper与Reducer任务独立运行。
所以大体上来说,这个UDAF函数读取数据(mapper),聚集一堆mapper输出到部分聚集结果(combiner),并且最终创建一个最终的聚集结果(reducer)。因为我们跨域多个combiner进行聚集,所以我们需要保存部分聚集结果。
内部类中必须实现的6个方法:
1、getNewAggregationBuffer():返回存储临时聚合结果的AggregationBuffer对象。
2、reset(AggregationBuffer agg):重置聚合结果,以支持mapper和reducer的重用。
3、iterate(AggregationBuffer agg,Object[] parameters):map阶段,迭代处理输入sql传过来的列数据。
4、terminatePartial(AggregationBuffer agg):map与combiner结束返回结果,得到部分数据聚集结果,以持久化的方式返回部分聚合结果,类似于 MapReduce的combiner
5、merge(AggregationBuffer agg,Object partial):接受来自 terminatePartial的返回结果,进行合并,hive合并两部分聚合的时候回调用这个方法
6、terminate(AggregationBuffer agg):返回最终结果。
还有一个初始化方法
ObjectInspector init : 确定各个阶段输入输出参数的数据格式ObjectInspectors
2.代码结构:
1)需继承AbstractGenericUDAFResolver抽象类,重写方法getEvaluator(TypeInfo[] parameters);
2)内部静态类需继承GenericUDAFEvaluator抽象类,重写方法init(),实现方法getNewAggregationBuffer(),reset(),iterate(),terminatePartial(),merge(),terminate()。
3.程序执行过程:
- PARTIAL1:从原始数据到部分聚合数据的过程,会调用 iterate() 和 terminatePartial() 方法。iterate() 函数负责解析输入数据,terminatePartial() 负责输出当前临时聚合结果。该阶段可以理解为对应 MapReduce 过程中的 Map 阶段。
- PARTIAL2:从部分聚合数据到部分聚合数据的过程(多次聚合),会调用 merge() 和 terminatePartial() 方法。merge() 函数负责聚合 Map 阶段 terminatePartial() 函数输出的部分聚合结果,terminatePartial() 负责输出当前临时聚合结果。阶段可以理解为对应 MapReduce 过程中的 Combine 阶段。
- FINAL: 从部分聚合数据到全部聚合数据的过程,会调用 merge() 和 terminate() 方法。merge() 函数负责聚合 Map 阶段或者 Combine 阶段 terminatePartial() 函数输出的部分聚合结果。terminate() 方法负责输出 Reduce 阶段最终的聚合结果。该阶段可以理解为对应 MapReduce 过程中的 Reduce 阶段。
- COMPLETE: 从原始数据直接到全部聚合数据的过程,会调用 iterate() 和 terminate() 方法。可以理解为 MapReduce 过程中的直接 Map 输出阶段,没有 Reduce 阶段。
@Description(name = "letters", value = "_FUNC_(expr) - 返回该列中所有字符串的字符总数")
public class TotalNumOfLettersGenericUDAF extends AbstractGenericUDAFResolver {
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
throws SemanticException {
if (parameters.length != 1) {
throw new UDFArgumentTypeException(parameters.length - 1,
"Exactly one argument is expected.");
}
ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
if (oi.getCategory() != ObjectInspector.Category.PRIMITIVE){
throw new UDFArgumentTypeException(0,
"Argument must be PRIMITIVE, but "
+ oi.getCategory().name()
+ " was passed.");
}
PrimitiveObjectInspector inputOI = (PrimitiveObjectInspector) oi;
if (inputOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING){
throw new UDFArgumentTypeException(0,
"Argument must be String, but "
+ inputOI.getPrimitiveCategory().name()
+ " was passed.");
}
return new TotalNumOfLettersEvaluator();
}
public static class TotalNumOfLettersEvaluator extends GenericUDAFEvaluator {
PrimitiveObjectInspector inputOI;
ObjectInspector outputOI;
PrimitiveObjectInspector integerOI;
int total = 0;
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
throws HiveException {
assert (parameters.length == 1);
super.init(m, parameters);
//map阶段读取sql列,输入为String基础数据格式
if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
inputOI = (PrimitiveObjectInspector) parameters[0];
} else {
//其余阶段,输入为Integer基础数据格式
integerOI = (PrimitiveObjectInspector) parameters[0];
}
// 指定各个阶段输出数据格式都为Integer类型
outputOI = ObjectInspectorFactory.getReflectionObjectInspector(Integer.class,
ObjectInspectorOptions.JAVA);
return outputOI;
}
/**
* 存储当前字符总数的类
*/
static class LetterSumAgg implements AggregationBuffer {
int sum = 0;
void add(int num){
sum += num;
}
}
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
LetterSumAgg result = new LetterSumAgg();
return result;
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
LetterSumAgg myagg = new LetterSumAgg();
}
private boolean warned = false;
@Override
public void iterate(AggregationBuffer agg, Object[] parameters)
throws HiveException {
assert (parameters.length == 1);
if (parameters[0] != null) {
LetterSumAgg myagg = (LetterSumAgg) agg;
Object p1 = ((PrimitiveObjectInspector) inputOI).getPrimitiveJavaObject(parameters[0]);
myagg.add(String.valueOf(p1).length());
}
}
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
LetterSumAgg myagg = (LetterSumAgg) agg;
total += myagg.sum;
return total;
}
@Override
public void merge(AggregationBuffer agg, Object partial)
throws HiveException {
if (partial != null) {
LetterSumAgg myagg1 = (LetterSumAgg) agg;
Integer partialSum = (Integer) integerOI.getPrimitiveJavaObject(partial);
LetterSumAgg myagg2 = new LetterSumAgg();
myagg2.add(partialSum);
myagg1.add(myagg2.sum);
}
}
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
LetterSumAgg myagg = (LetterSumAgg) agg;
total = myagg.sum;
return myagg.sum;
}
}
UDTF (单输入多输出)
继承GenericUDTF
需要实现的方法void close()
StructObjectInspector initialize(ObjectInspector[] args)
//定义输出数据的列名和数据类型
void process(Object[] args)
//对数据进行操作
例子:根据传入的参数,将数据切分为多行
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import java.util.ArrayList;
import java.util.List;
public class MyUDTF extends GenericUDTF {
private ArrayList<String> outList = new ArrayList<>();
@Override
public StructObjectInspector initialize(StructObjectInspector argOIs) throws UDFArgumentException {
//1.定义输出数据的列名和类型
List<String> fieldNames = new ArrayList<>();
List<ObjectInspector> fieldOIs = new ArrayList<>();
//2.添加输出数据的列名和类型
fieldNames.add("lineToWord");
fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
@Override
public void process(Object[] args) throws HiveException {
//1.获取原始数据
String arg = args[0].toString();
//2.获取数据传入的第二个参数,此处为分隔符
String splitKey = args[1].toString();
//3.将原始数据按照传入的分隔符进行切分
String[] fields = arg.split(splitKey);
//4.遍历切分后的结果,并写出
for (String field : fields) {
//集合为复用的,首先清空集合
outList.clear();
//将每一个单词添加至集合
outList.add(field);
//将集合内容写出
forward(outList);
}
}
@Override
public void close() throws HiveException {
注册&使用
hive (default)> add jar /opt/module/hive/lib/hive-demo-1.0-SNAPSHOT.jar;
Added [/opt/module/hive/lib/hive-demo-1.0-SNAPSHOT.jar] to class path
Added resources: [/opt/module/hive/lib/hive-demo-1.0-SNAPSHOT.jar]
hive (default)> create temporary function my_len as "com.atguigu.udf.MyUDF";
hive (default)> select my_len('zhang') ;