1. Introduction
A previous article covered how to write Spark UDF functions:
https://blog.csdn.net/Aaron_ch/article/details/113346185
So what exactly is a UDTF, and how do you use one in Spark?
1.1. What Is a UDTF?
Anyone who has worked in big data and knows Hive will be familiar with its many commonly used built-in UDTFs, such as
explode, json_tuple, get_splits,
and so on.
A UDTF takes one input row and turns it into multiple rows and columns. Put simply:
输入 {"test01":"hhh","test02":{"test03":"yyyy","test04":"uuuu"}} 这样的字符串
输出
col1 | col2 |
hhh | yyyy |
hhh | uuuu |
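You can see this row-expansion behavior with a built-in UDTF before writing your own. A minimal sketch, assuming a SparkSession named sparkSession with Hive support enabled:

// explode() is a built-in UDTF: the 3-element array becomes 3 rows, one per element
sparkSession.sql("SELECT explode(array('a','b','c'))").show();
// +---+
// |col|
// +---+
// |  a|
// |  b|
// |  c|
// +---+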
1.2. How to Use One
Looking through the Spark source code, there is actually no UDTF interface at all. The official documentation says:
Spark SQL supports integration of Hive UDFs, UDAFs and UDTFs. Similar to Spark UDFs and UDAFs, Hive UDFs work on a single row as input and generate a single row as output, while Hive UDAFs operate on multiple rows and return a single aggregated row as a result. In addition, Hive also supports UDTFs (User Defined Tabular Functions) that act on one row as input and return multiple rows as output. To use Hive UDFs/UDAFs/UTFs, the user should register them in Spark, and then use them in Spark SQL queries.
It is clear from this that Spark's UDTF support is exactly Hive's. Official docs: https://spark.apache.org/docs/3.1.1/sql-ref-functions-udf-hive.html
Looking at the official example, all you need to do when writing one is extend org.apache.hadoop.hive.ql.udf.generic.GenericUDTF.
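As in Hive, a registered UDTF can also be applied row by row with LATERAL VIEW. A quick sketch; the table people and its array column hobbies are made up for illustration:

// One output row per (id, hobby) pair; tmp is the lateral view alias
sparkSession.sql("SELECT id, hobby FROM people LATERAL VIEW explode(hobbies) tmp AS hobby").show();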
1.2.1. Code Example
Take parsing a JSON string as the example. The main class:
import com.alibaba.fastjson.JSONArray;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.Map;

public class AnalysisJsonToArrayUDTF extends GenericUDTF {

    @Override
    public StructObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException {
        if (args.length < 2) {
            throw new UDFArgumentLengthException("At least two parameters are needed, please check!");
        }
        for (int i = 0; i < args.length; ++i) {
            if (args[i].getCategory() != ObjectInspector.Category.PRIMITIVE
                    || !args[i].getTypeName().equals("string")) {
                throw new UDFArgumentException("get_json_array()'s arguments have to be string type");
            }
        }
        // Column names of the returned rows
        ArrayList<String> fieldNames = new ArrayList<>();
        fieldNames.add("col");
        // Column types of the returned rows
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<>();
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    public void process(Object[] objects) throws HiveException {
        if (objects.length < 2) {
            throw new UDFArgumentLengthException("At least two parameters are needed, please check!");
        }
        String jsonStr = objects[0].toString();
        String jsonKey = objects[1].toString();
        // JsonUtils is the author's helper class (not shown here): it flattens the nested
        // JSON into dot-separated key paths. An IdentityHashMap is used so duplicate
        // Integer keys don't collapse entries.
        Map<Integer, Map<String, String>> keyValueMap2 = new IdentityHashMap<>();
        JsonUtils.dbJSONFormatIntoMap(JsonUtils.str2FastJSON(jsonStr), keyValueMap2);
        String psKey;
        String psValue;
        String key;
        for (Map.Entry<Integer, Map<String, String>> mapEntry : keyValueMap2.entrySet()) {
            for (Map.Entry<String, String> mapEntry2 : mapEntry.getValue().entrySet()) {
                psKey = mapEntry2.getKey();
                psValue = mapEntry2.getValue();
                int keyLength = psKey.split("[.]").length;
                // Top-level key: emit the value as one output row. Note that forward()
                // must be given an array with one element per declared output column.
                if (keyLength < 2 && jsonKey.equals(psKey)) {
                    forward(new Object[]{psValue});
                }
                // Nested key: match on the last segment of the dotted path.
                key = psKey.split("[.]")[keyLength - 1];
                if (keyLength >= 2 && key.equals(jsonKey) && !StringUtils.isEmpty(psValue)) {
                    String mp2 = "{\"" + key + "\":" + psValue + "}";
                    for (Map.Entry<String, Object> me2 : JsonUtils.str2FastJSON(mp2).entrySet()) {
                        if (me2.getValue() instanceof JSONArray) {
                            // A JSON array becomes one output row per element
                            JSONArray jsonArray = (JSONArray) me2.getValue();
                            for (int i = 0; i < jsonArray.size(); i++) {
                                forward(new Object[]{jsonArray.get(i).toString()});
                            }
                        } else {
                            // Scalar value: strip surrounding brackets and emit one row
                            forward(new Object[]{JsonUtils.trimBothEndsChars(psValue, "\\[\\]")});
                        }
                    }
                }
            }
        }
    }

    @Override
    public void close() throws HiveException {
    }
}
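The class above returns a single column named col. To produce the multi-column shape from section 1.1, declare one field name and one inspector per column in initialize() and forward an array of the same width. A minimal sketch; the class, function, and column names here are illustrative, not from the original post:

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import java.util.ArrayList;

// Hypothetical two-column UDTF: splits "k1:v1,k2:v2" into one (col1, col2) row per pair
public class ExplodeKeyValueUDTF extends GenericUDTF {
    @Override
    public StructObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException {
        ArrayList<String> fieldNames = new ArrayList<>();
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<>();
        fieldNames.add("col1");
        fieldNames.add("col2");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    public void process(Object[] objects) throws HiveException {
        for (String pair : objects[0].toString().split(",")) {
            String[] kv = pair.split(":", 2);
            // forward() takes one array per output row, one element per declared column
            forward(new Object[]{kv[0], kv.length > 1 ? kv[1] : null});
        }
    }

    @Override
    public void close() throws HiveException {
    }
}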
Calling it from a main function:
import org.apache.spark.SparkConf;
import org.apache.spark.sql.SparkSession;

// Wrapper class name is arbitrary; only main() is from the original post
public class UdtfDemo {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf()
                .setAppName("Sync")
                .set("hive.exec.dynamic.partition", "true")
                .set("hive.exec.dynamic.partition.mode", "nonstrict")
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                .set("spark.sql.autoBroadcastJoinThreshold", "204800")
                .set("spark.debug.maxToStringFields", "1000")
                .set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
                .setMaster("local[*]");
        SparkSession sparkSession = SparkSession.builder()
                .config(conf).enableHiveSupport().getOrCreate();
        // Register the UDTF under a SQL name; use the fully qualified class name if it lives in a package
        sparkSession.sql("create temporary function marketing_json_trun as 'AnalysisJsonToArrayUDTF'");
        String jsonStr = "{\"test01\":[{\"gj\":{\"sf\":\"js\"}},{\"ds\":\"nj\"},{\"ds\":\"sh\"}]}";
        sparkSession.sql("select marketing_json_trun('" + jsonStr + "','ds')").show();
    }
}
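Assuming the JsonUtils helper flattens nested keys into dot paths the way process() expects, the query above should emit one row per matching "ds" value, something like:

+---+
|col|
+---+
| nj|
| sh|
+---+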