题目说明
编写一个udf,输入这个数组之后按多列输出
题解
udtf其实是udf里面比较少自己去写的东西,所以反而是盲区,这种题目就是摸过的就觉得简单,所谓难者不会,会者不难
代码会放在最后,我说一下精髓部分!
凡是UDF编写,关键点是了解计算特征,udtf其实关键点就是输入一行,可以输出多行,不管里头怎么折腾,反正只要是Java代码写的,输入一行的话我们无非都是一个输入参数,输出多行要么是数组,要么是集合。
输入输出参数说明
输入部分,其实就是select 函数的时候给的那一排参数
比如select udf(1,2,3) 那么args(0)=1,args(1)=2,args(3)=3
题目给的是数组,数组在输入参数的时候我们当成复杂参数来对待,这里不能混淆
比如select udf([1,2,3],[4,5,6])的时候,对应的就是args(0)=[1,2,3],args(1)=[4,5,6]
题目中的是select udf([1,2,3])的形式
输出部分,正如我们说的,不是类型就是数组类型,udtf中输出的其实是一个数组的类型,有一个元素就
forward(array)会调用一次,在process里面调用多次的时候就是多个结果输出
初始化部分
初始化部分其实是给一个表头,也就是结果中的默认列,因为可以输出多行嘛,需要有列名,一个是列名,一个是列的类型:
对于结果的话:
源码
最后,附上代码源码
package org.apache.spark.udf;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.google.common.collect.Lists;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
public class MultiplyRow extends GenericUDTF {
@Override
public void close() throws HiveException {
}
@Override
public StructObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException {
if (args.length != 1) {
throw new UDFArgumentLengthException("ExplodeMap takes only one argument");
}
if (args[0].getCategory() != ObjectInspector.Category.LIST) {
throw new UDFArgumentException("ExplodeMap takes array as a parameter");
}
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
fieldNames.add("row");
fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
fieldNames.add("value");
fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
@Override
public void process(Object[] args) throws HiveException {
List<String> ls= (ArrayList) args[0];
for (int i = 0; i < ls.size(); i++) {
try {
String[] result=new String[2];
result[0]="row"+String.valueOf(i);
result[1]=ls.get(i);
forward(result);
} catch (Exception e) {
continue;
}
}
}
}
这部分是测试代码
package org.apache.spark;
import org.apache.spark.sql.SparkSession;
import java.io.File;
public class MultiplyRowTest {
public static void main(String[] args) {
String warehouseLocation = new File("spark-warehouse").getAbsolutePath();
SparkSession spark = SparkSession
.builder()
.appName("MultiplyRowTest")
.config("spark.sql.warehouse.dir", warehouseLocation)
.master("local[*]")
.enableHiveSupport()
.getOrCreate();
// spark.sql("CREATE TEMPORARY FUNCTION myudf as 'org.apache.spark.udf.UserDefinedUDTF'");
spark.sql("create temporary function udtf as 'org.apache.spark.udf.MultiplyRow'");
spark.sql("select udtf(split('a,b,c',',')) ").show();
// spark.sql("select 'a' as c1, myudf('a,b,c') as array ").show();
}
}