UDF Function Development

Hive provides two different interfaces for writing UDFs: the UDF interface, which operates on simple data types, and the GenericUDF interface, which operates on complex types.

The simple UDF interface is not covered in depth here.
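For reference, a UDF built on the simple interface is just a class extending org.apache.hadoop.hive.ql.exec.UDF with an evaluate method that Hive resolves by reflection. A minimal sketch, assuming a hypothetical ToUpper function (not part of this article's code):

import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.Text;

public class ToUpper extends UDF {
    // Hive finds this method by reflection: one simple type in, one simple type out
    public Text evaluate(Text input) {
        if (input == null) {
            return null;
        }
        return new Text(input.toString().toUpperCase());
    }
}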

Now let's look at a GenericUDF example.

DDL for the ary table:

create table ary(
array1 array<struct<id:string,name:string>>,
array2 array<struct<addr:string,id:string,dt:string>>
)
row format delimited fields terminated by '\t'
map keys terminated by ','; 

The data in the ary table:

hive > select * from ary;
OK
ary.array1      ary.array2
[{"id":"001","name":"明明"}]    [{"addr":"杨浦","id":"001","dt":"20180703"}]
[{"id":"002","name":"阿达"}]    [{"addr":"黄浦","id":"004","dt":"20180629"}]
[{"id":"003","name":"阿珂"}]    [{"addr":"浦东","id":"003","dt":"20180817"}]
[{"id":"004","name":"小宝"}]    [{"addr":"松江","id":"002","dt":"20180623"}]

The result we want:

hive > select aa(array1,array2,'id') from ary;                
OK
_c0
[{"id":"001","name":"明明","addr":"杨浦","dt":"20180703"}]
[{"id":"NULL","name":"NULL","addr":"NULL","dt":"NULL"}]
[{"id":"003","name":"阿珂","addr":"浦东","dt":"20180817"}]
[{"id":"NULL","name":"NULL","addr":"NULL","dt":"NULL"}]

Function logic:

Operating on one row at a time, the function uses the supplied field name to match field values between the two arrays of structs. When the join values are equal it returns all of the matched fields (the join field itself appears only once); otherwise every output field is set to NULL.
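Stripped of the Hive plumbing, the per-element matching step amounts to the plain-Java sketch below; the MatchSketch class and its map-based structs are illustrative assumptions, not the UDF's actual data representation:

import java.util.LinkedHashMap;
import java.util.Map;

public class MatchSketch {
    // Merge two struct-like maps on joinField; emit "NULL" everywhere when the join values differ
    static Map<String, String> merge(Map<String, String> s1, Map<String, String> s2, String joinField) {
        boolean matched = s1.get(joinField).equals(s2.get(joinField));
        Map<String, String> out = new LinkedHashMap<String, String>();
        for (Map.Entry<String, String> e : s1.entrySet()) {
            out.put(e.getKey(), matched ? e.getValue() : "NULL");
        }
        for (Map.Entry<String, String> e : s2.entrySet()) {
            if (!e.getKey().equals(joinField)) { // the join field appears only once
                out.put(e.getKey(), matched ? e.getValue() : "NULL");
            }
        }
        return out;
    }
}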

Writing the custom UDF:

package com.paic.gbd.udfarray;

import org.apache.hadoop.io.Text;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.logging.Log;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.lazy.LazyStruct;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import java.util.ArrayList;
import java.util.List;
 
public class GenericArrayUnion extends GenericUDF {

     // Logger for debugging; output goes to hive.log, whose location is set in hive-log4j.properties
     static final Log LOG = LogFactory.getLog(GenericArrayUnion.class.getName());

     // Input definitions
     private static final int ARG_COUNT = 3; // number of arguments to this UDF
     private static final String FUNC_NAME = "array_union"; // external name

     private ListObjectInspector arrayOI;
     private ListObjectInspector arrayOI2;
     private ObjectInspector strOI;

     private StructObjectInspector structOI;
     private StructObjectInspector structOI2;
     private ArrayList<Object> valueList = new ArrayList<Object>();
     private ArrayList<Object> valueList2 = new ArrayList<Object>();
     // Position of the join field in each struct, and the field count of each struct
     private int num1 = 0;
     private int num2 = 0;
     private int valueListLength = 0;
     private int valueList2Length = 0;
     @Override
     public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
          // Check the argument count before touching the arguments array
          if (arguments.length != ARG_COUNT) {
               throw new UDFArgumentException("The function " + FUNC_NAME
                         + " accepts " + ARG_COUNT + " arguments.");
          }

          // Arguments 1 and 2 must be arrays
          if (!arguments[0].getCategory().equals(Category.LIST) || !arguments[1].getCategory().equals(Category.LIST)) {
               throw new UDFArgumentException(
                         "\"" + org.apache.hadoop.hive.serde.serdeConstants.LIST_TYPE_NAME + "\" "
                                   + "expected at function arguments, but \"" + arguments[0].getTypeName()
                                   + "\" and \"" + arguments[1].getTypeName() + "\" found.");
          }

          arrayOI = (ListObjectInspector) arguments[0];
          arrayOI2 = (ListObjectInspector) arguments[1];
          strOI = arguments[2];

          // Both arrays must hold struct elements
          if (arrayOI.getListElementObjectInspector().getCategory() != Category.STRUCT
                    || arrayOI2.getListElementObjectInspector().getCategory() != Category.STRUCT) {
               throw new UDFArgumentException(
                         "\"" + org.apache.hadoop.hive.serde.serdeConstants.STRUCT_TYPE_NAME + "\" "
                                   + "expected as the ARRAY element type, but \""
                                   + arrayOI.getListElementObjectInspector().getCategory()
                                   + "\" and \"" + arrayOI2.getListElementObjectInspector().getCategory()
                                   + "\" found.");
          }
          structOI = (StructObjectInspector) arrayOI.getListElementObjectInspector();
          structOI2 = (StructObjectInspector) arrayOI2.getListElementObjectInspector();

          // Argument 3 (the join field name) must be a primitive
          if (!arguments[2].getCategory().equals(Category.PRIMITIVE)) {
               throw new UDFArgumentException(
                         "\"" + org.apache.hadoop.hive.serde.serdeConstants.PrimitiveTypes + "\" "
                                   + "expected at function arguments, but \"" + arguments[2].getTypeName()
                                   + "\" found.");
          }
          
          // Build the struct type for the output array elements
          ArrayList<String> structFieldNames = new ArrayList<String>();
          ArrayList<ObjectInspector> structFieldObjectInspectors = new ArrayList<ObjectInspector>();
          String joinField = ObjectInspectorUtils.getWritableConstantValue(arguments[2]).toString();
          valueListLength = 0;
          int strOINum1 = 0;
          for (int i = 0; i < structOI.getAllStructFieldRefs().size(); i++) {
               valueListLength++;
               String value1 = structOI.getAllStructFieldRefs().get(i).getFieldName();
               structFieldNames.add(value1);
               if (joinField.equals(value1)) {
                    // Position of the join field in the first array's struct
                    num1 = i;
                    strOINum1++;
               }
          }
          if (strOINum1 != 1) {
               throw new UDFArgumentException(
                         "The first array must contain exactly one field named by the third argument");
          }
          valueList2Length = 0;
          int strOINum2 = 0;
          for (int i = 0; i < structOI2.getAllStructFieldRefs().size(); i++) {
               valueList2Length++;
               String value2 = structOI2.getAllStructFieldRefs().get(i).getFieldName();
               if (!joinField.equals(value2)) {
                    structFieldNames.add(value2);
               } else {
                    // Position of the join field in the second array's struct
                    num2 = i;
                    strOINum2++;
               }
          }
          if (strOINum2 != 1) {
               throw new UDFArgumentException(
                         "The second array must contain exactly one field named by the third argument");
          }

          // Every output field is exposed as a writable string
          for (int i = 0; i < structFieldNames.size(); i++) {
               structFieldObjectInspectors.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
          }

          StructObjectInspector si2 = ObjectInspectorFactory.getStandardStructObjectInspector(
                    structFieldNames, structFieldObjectInspectors);

          return ObjectInspectorFactory.getStandardListObjectInspector(si2);
     }
 
     @Override
     public Object evaluate(DeferredObject[] arguments) throws HiveException {
          // Debug logging; visible in hive.log
          LOG.info("array_union: num1=" + num1 + " num2=" + num2
                    + " valueListLength=" + valueListLength + " valueList2Length=" + valueList2Length);

          int arr_len = arrayOI.getListLength(arguments[0].get());

          valueList.clear();
          valueList2.clear();

          // One output slot per field of both structs, counting the join field once
          String[] value = new String[valueListLength + valueList2Length - 1];
          LOG.info("array_union: array length = " + arr_len);

          // Walk the two arrays element by element
          for (int i = 0; i < arr_len; i++) {
               // Fetch the i-th element of the first array
               LazyStruct LStruct = (LazyStruct) arrayOI.getListElement(arguments[0].get(), i);
               // Collect all field values of the struct
               List<Object> field_value = structOI.getStructFieldsDataAsList(LStruct);
               valueList.addAll(field_value);

               // Fetch the i-th element of the second array
               LazyStruct LStruct2 = (LazyStruct) arrayOI2.getListElement(arguments[1].get(), i);
               // Collect all field values of the struct
               List<Object> field_value2 = structOI2.getStructFieldsDataAsList(LStruct2);
               valueList2.addAll(field_value2);

               // If the join field values match, copy all fields to the output; otherwise emit "NULL"
               if (valueList.get(num1).toString().equals(valueList2.get(num2).toString())) {
                    // Copy the fields of the first struct
                    for (int j = 0; j < valueList.size(); j++) {
                         value[j] = valueList.get(j).toString();
                    }
                    // Copy the fields of the second struct, skipping the join field;
                    // the separate index n keeps the value array from overflowing
                    for (int j = 0, n = 0; j < valueList2.size(); j++) {
                         if (j != num2) {
                              value[n + valueList.size()] = valueList2.get(j).toString();
                              n++;
                         }
                    }
               } else {
                    for (int j = 0; j < value.length; j++) {
                         value[j] = "NULL";
                    }
               }

               valueList.clear();
               valueList2.clear();
          }

          // Wrap the values as Text and return a single-element list holding the merged struct
          Object[] e = new Object[value.length];
          for (int i = 0; i < value.length; i++) {
               e[i] = new Text(value[i]);
          }

          ArrayList<Object> result = new ArrayList<Object>();
          result.add(e);
          return result;
     }
   
     @Override
     public String getDisplayString(String[] children) {
          assert (children.length == ARG_COUNT);

          StringBuilder sb = new StringBuilder();
          sb.append("array_union(");
          for (int i = 0; i < children.length; i++) {
               if (i > 0) {
                    sb.append(", ");
               }
               sb.append(children[i]);
          }
          sb.append(")");

          return sb.toString();
     }

}

Package the code into a jar, upload it to the Linux host, add the jar in the Hive CLI, and create the function:

hive > add jar /home/hadoop/aa.jar;                                    
Added /home/hadoop/aa.jar to class path
Added resource: /home/hadoop/aa.jar
hive > create function aa as 'com.paic.gbd.udfarray.GenericArrayUnion';
OK
Time taken: 0.031 seconds
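Beyond rerunning the query in the CLI, the function can also be exercised through HiveServer2 over JDBC. A minimal smoke-test sketch; the connection URL, user, and the ArrayUnionSmokeTest class name are placeholders for your environment:

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;

public class ArrayUnionSmokeTest {
    public static void main(String[] args) throws Exception {
        // Driver class and connection details are assumptions; adjust for your cluster
        Class.forName("org.apache.hive.jdbc.HiveDriver");
        Connection conn = DriverManager.getConnection(
                "jdbc:hive2://localhost:10000/default", "hadoop", "");
        Statement stmt = conn.createStatement();
        // Same query as above; each row should print one merged struct array
        ResultSet rs = stmt.executeQuery("select aa(array1, array2, 'id') from ary");
        while (rs.next()) {
            System.out.println(rs.getString(1));
        }
        rs.close();
        stmt.close();
        conn.close();
    }
}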

The function now produces the result shown at the beginning.
