Hive offers two different interfaces for writing UDFs: the plain UDF interface, which operates on simple data types, and the GenericUDF interface, which can handle complex types. The plain UDF interface is not covered in detail here.
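For contrast, the simple interface just extends org.apache.hadoop.hive.ql.exec.UDF and overloads evaluate() with concrete types. Here is a minimal sketch; the class and its logic are purely illustrative and not part of this article's function:

package com.paic.gbd.udfarray;

import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.Text;

// Illustrative only: evaluate() is resolved by reflection against fixed
// types, which is why this interface cannot express array<struct<...>>.
public class ToUpper extends UDF {
    public Text evaluate(Text s) {
        if (s == null) {
            return null;
        }
        return new Text(s.toString().toUpperCase());
    }
}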
Now let's look at a GenericUDF example.
DDL for the ary table:
create table ary(
array1 array<struct<id:string,name:string>>,
array2 array<struct<addr:string,id:string,dt:string>>
)
row format delimited fields terminated by '\t'
map keys terminated by ',';
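A note on the delimiters, since the text SerDe reuses its delimiter hierarchy for nested types: '\t' separates the two array columns, array elements fall back to the default '\002' separator, and the map-keys delimiter ',' is what ends up separating the struct fields. Assuming that layout, the first data row is stored roughly as (\t standing for a literal tab):

001,明明\t杨浦,001,20180703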
Data in the ary table:
hive > select * from ary;
OK
ary.array1 ary.array2
[{"id":"001","name":"明明"}] [{"addr":"杨浦","id":"001","dt":"20180703"}]
[{"id":"002","name":"阿达"}] [{"addr":"黄浦","id":"004","dt":"20180629"}]
[{"id":"003","name":"阿珂"}] [{"addr":"浦东","id":"003","dt":"20180817"}]
[{"id":"004","name":"小宝"}] [{"addr":"松江","id":"002","dt":"20180623"}]
The result we want:
hive > select aa(array1,array2,'id') from ary;
OK
_c0
[{"id":"001","name":"明明","addr":"杨浦","dt":"20180703"}]
[{"id":"NULL","name":"NULL","addr":"NULL","dt":"NULL"}]
[{"id":"003","name":"阿珂","addr":"浦东","dt":"20180817"}]
[{"id":"NULL","name":"NULL","addr":"NULL","dt":"NULL"}]
Function logic:
For each input row, use the field name passed as the third argument to match up the elements of the two struct arrays. If the key values are equal, return every field of both structs merged into one struct (the id appearing only once, of course); otherwise set every output field to "NULL". In row 2 above, for example, the ids 002 and 004 differ, so the whole output struct is "NULL".
The custom GenericUDF:
package com.paic.gbd.udfarray;
import org.apache.hadoop.io.Text;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import java.util.ArrayList;
import java.util.List;
public class GenericArrayUnion extends GenericUDF {
    // Debug logging; entries appear in hive.log, whose location is set in
    // the hive-log4j.properties configuration file.
    static final Log LOG = LogFactory.getLog(GenericArrayUnion.class.getName());
    // Input definitions
    private static final int ARG_COUNT = 3;                 // number of arguments to this UDF
    private static final String FUNC_NAME = "array_union";  // external name
    private ListObjectInspector arrayOI;
    private ListObjectInspector arrayOI2;
    private StructObjectInspector structOI;
    private StructObjectInspector structOI2;
    private ArrayList<Object> valueList = new ArrayList<Object>();
    private ArrayList<Object> valueList2 = new ArrayList<Object>();
    private int num1 = 0;             // position of the join field in the first struct
    private int num2 = 0;             // position of the join field in the second struct
    private int valueListLength = 0;  // field count of the first struct
    private int valueList2Length = 0; // field count of the second struct
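    // initialize() runs once per query, at plan-compile time. It receives one
    // ObjectInspector per call-site argument, validates them, and returns the
    // ObjectInspector that describes this UDF's return type.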
    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        // Validate the argument count before any casts.
        if (arguments.length != ARG_COUNT) {
            throw new UDFArgumentException("The function " + FUNC_NAME
                    + " accepts " + ARG_COUNT + " arguments.");
        }
        // Arguments 1 and 2 must be arrays.
        if (!arguments[0].getCategory().equals(Category.LIST)
                || !arguments[1].getCategory().equals(Category.LIST)) {
            throw new UDFArgumentException(
                    "\"" + org.apache.hadoop.hive.serde.serdeConstants.LIST_TYPE_NAME + "\" "
                    + "expected at function arguments, but \"" + arguments[0].getTypeName()
                    + "\" and \"" + arguments[1].getTypeName() + "\" found.");
        }
        arrayOI = (ListObjectInspector) arguments[0];
        arrayOI2 = (ListObjectInspector) arguments[1];
        // Both arrays must hold struct elements.
        if (arrayOI.getListElementObjectInspector().getCategory() != Category.STRUCT
                || arrayOI2.getListElementObjectInspector().getCategory() != Category.STRUCT) {
            throw new UDFArgumentException(
                    "\"" + org.apache.hadoop.hive.serde.serdeConstants.STRUCT_TYPE_NAME + "\" "
                    + "expected as the ARRAY element type, but \""
                    + arrayOI.getListElementObjectInspector().getCategory() + "\" and \""
                    + arrayOI2.getListElementObjectInspector().getCategory() + "\" found.");
        }
        structOI = (StructObjectInspector) arrayOI.getListElementObjectInspector();
        structOI2 = (StructObjectInspector) arrayOI2.getListElementObjectInspector();
        // Argument 3 (the join field name) must be a primitive, passed as a constant.
        if (!arguments[2].getCategory().equals(Category.PRIMITIVE)) {
            throw new UDFArgumentException(
                    "\"" + org.apache.hadoop.hive.serde.serdeConstants.PrimitiveTypes + "\" "
                    + "expected at function arguments, but \"" + arguments[2].getTypeName()
                    + "\" found.");
        }
        // Build the output type: an array of structs whose field list is the
        // union of both input structs (the join field appears only once).
        ArrayList<String> structFieldNames = new ArrayList<String>();
        ArrayList<ObjectInspector> structFieldObjectInspectors = new ArrayList<ObjectInspector>();
        String joinField = ObjectInspectorUtils.getWritableConstantValue(arguments[2]).toString();
        valueListLength = 0;
        int strOINum1 = 0;
        for (int i = 0; i < structOI.getAllStructFieldRefs().size(); i++) {
            valueListLength++;
            String fieldName = structOI.getAllStructFieldRefs().get(i).getFieldName();
            structFieldNames.add(fieldName);
            if (joinField.equals(fieldName)) {
                // Position of the join field inside the first array's struct
                num1 = i;
                strOINum1++;
            }
        }
        if (strOINum1 != 1) {
            throw new UDFArgumentException(
                    "The first array's struct must contain the field named by the third argument exactly once");
        }
        valueList2Length = 0;
        int strOINum2 = 0;
        for (int i = 0; i < structOI2.getAllStructFieldRefs().size(); i++) {
            valueList2Length++;
            String fieldName2 = structOI2.getAllStructFieldRefs().get(i).getFieldName();
            if (!joinField.equals(fieldName2)) {
                structFieldNames.add(fieldName2);
            } else {
                // Position of the join field inside the second array's struct
                num2 = i;
                strOINum2++;
            }
        }
        if (strOINum2 != 1) {
            throw new UDFArgumentException(
                    "The second array's struct must contain the field named by the third argument exactly once");
        }
        // Every output field is a string.
        for (int i = 0; i < structFieldNames.size(); i++) {
            structFieldObjectInspectors.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        }
        StructObjectInspector si2 = ObjectInspectorFactory.getStandardStructObjectInspector(
                structFieldNames, structFieldObjectInspectors);
        return ObjectInspectorFactory.getStandardListObjectInspector(si2);
    }
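    // evaluate() runs once per input row. arguments[i].get() yields the i-th
    // argument in whatever physical representation its ObjectInspector from
    // initialize() describes, so every access has to go through the OIs.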
    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        int arr_len = arrayOI.getListLength(arguments[0].get());
        LOG.info("join positions " + num1 + "/" + num2
                + ", struct widths " + valueListLength + "/" + valueList2Length
                + ", array length " + arr_len);
        ArrayList<Object[]> result = new ArrayList<Object[]>();
        // Walk the two arrays pairwise; they are assumed to be the same length.
        for (int i = 0; i < arr_len; i++) {
            valueList.clear();
            valueList2.clear();
            // All field values of the i-th struct of the first array; going
            // through the ObjectInspector avoids casting to LazyStruct, so the
            // UDF also works when the input is not in the lazy representation.
            Object element = arrayOI.getListElement(arguments[0].get(), i);
            List<Object> fieldValues = structOI.getStructFieldsDataAsList(element);
            valueList.addAll(fieldValues);
            // All field values of the i-th struct of the second array
            Object element2 = arrayOI2.getListElement(arguments[1].get(), i);
            List<Object> fieldValues2 = structOI2.getStructFieldsDataAsList(element2);
            valueList2.addAll(fieldValues2);
            String[] value = new String[valueListLength + valueList2Length - 1];
            // If the join field values match, emit the merged fields;
            // otherwise every output field becomes the string "NULL".
            if (valueList.get(num1).toString().equals(valueList2.get(num2).toString())) {
                // Fields of the first struct
                for (int j = 0; j < valueList.size(); j++) {
                    value[j] = valueList.get(j).toString();
                }
                // Fields of the second struct, skipping the join field; the extra
                // counter n keeps the target index from running past the array end.
                for (int j = 0, n = 0; j < valueList2.size(); j++) {
                    if (j != num2) {
                        value[n + valueList.size()] = valueList2.get(j).toString();
                        n++;
                    }
                }
            } else {
                for (int j = 0; j < value.length; j++) {
                    value[j] = "NULL";
                }
            }
            // Wrap the merged row as one output struct (an Object[] of Text),
            // one per array element rather than one per row, so inputs with
            // more than one element are not silently collapsed.
            Object[] e = new Object[value.length];
            for (int j = 0; j < value.length; j++) {
                e[j] = new Text(value[j]);
            }
            result.add(e);
        }
        return result;
    }
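    // getDisplayString() supplies the text Hive prints for this call,
    // e.g. in EXPLAIN output and error messages.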
    @Override
    public String getDisplayString(String[] children) {
        assert (children.length == ARG_COUNT);
        StringBuilder sb = new StringBuilder();
        sb.append("array_union(");
        for (int i = 0; i < children.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(children[i]);
        }
        sb.append(")");
        return sb.toString();
    }
}
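Before deploying, the two lifecycle methods can also be driven by hand outside of Hive, which is handy for debugging. Below is a minimal, hypothetical harness; the class name and test data are made up, and it assumes a Hive 1.x/2.x classpath (where getPrimitiveWritableConstantObjectInspector takes a PrimitiveTypeInfo) plus the evaluate() above, which no longer insists on lazy objects:

package com.paic.gbd.udfarray;

import java.util.Arrays;
import java.util.List;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.Text;

// Hypothetical local harness; not part of the UDF itself.
public class GenericArrayUnionDemo {
    public static void main(String[] args) throws Exception {
        GenericArrayUnion udf = new GenericArrayUnion();
        // ObjectInspectors mirroring array<struct<id:string,name:string>>
        // and array<struct<addr:string,id:string,dt:string>>
        ObjectInspector elem1 = ObjectInspectorFactory.getStandardStructObjectInspector(
                Arrays.asList("id", "name"),
                Arrays.<ObjectInspector>asList(
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector));
        ObjectInspector elem2 = ObjectInspectorFactory.getStandardStructObjectInspector(
                Arrays.asList("addr", "id", "dt"),
                Arrays.<ObjectInspector>asList(
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector));
        udf.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(elem1),
                ObjectInspectorFactory.getStandardListObjectInspector(elem2),
                // The join field name must be a constant, as in the SQL call
                PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                        TypeInfoFactory.stringTypeInfo, new Text("id"))});
        // One row of input: standard structs are plain Lists of field values
        Object array1 = Arrays.<Object>asList(Arrays.asList("001", "明明"));
        Object array2 = Arrays.<Object>asList(Arrays.asList("杨浦", "001", "20180703"));
        Object out = udf.evaluate(new DeferredObject[] {
                new DeferredJavaObject(array1),
                new DeferredJavaObject(array2),
                new DeferredJavaObject(new Text("id"))});
        for (Object struct : (List<?>) out) {
            System.out.println(Arrays.toString((Object[]) struct)); // [001, 明明, 杨浦, 20180703]
        }
    }
}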
Package the code as a jar, upload it to the Linux machine, then add the jar in the Hive CLI and create the function:
hive > add jar /home/hadoop/aa.jar;
Added /home/hadoop/aa.jar to class path
Added resource: /home/hadoop/aa.jar
hive > create function aa as 'com.paic.gbd.udfarray.GenericArrayUnion';
OK
Time taken: 0.031 seconds
The function now produces exactly the result shown at the start.