一、UDF说明
collect_map(x,y): Returns a map of entries formed by taking x as the key and y as the value. Groups with duplicate keys will contain only one entry, and the value of that entry is indeterminate.
返回以x为键,y为值形成的map。具有重复key的组将只包含一个条目,其值是不确定的。
二、 代码
package com.scb.dss.udaf;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import java.util.HashMap;
import java.util.Map;
/**
* GenericUDAFCollectMap
*/
/**
 * collect_map(x, y): Hive UDAF that aggregates (key, value) pairs of a group
 * into a single map. Keys must be primitive; values may be any Hive type.
 * When a group contains duplicate keys, only one entry survives and which
 * value it carries is indeterminate (plain {@link HashMap} overwrite semantics).
 */
@Description(name = "collect_map", value = "_FUNC_(x,y) - Returns a map of entries formed by taking x as the key and y as the value. " +
"Groups with duplicate keys will contain only one entry; the value of that entry is indeterminate.")
public class GenericUDAFCollectMap extends AbstractGenericUDAFResolver {
static final Log LOG = LogFactory.getLog(GenericUDAFCollectMap.class.getName());

public GenericUDAFCollectMap() {
}

/**
 * Validates the argument list and returns the evaluator.
 *
 * @param parameters type info of the call-site arguments; exactly two expected
 * @throws SemanticException if the argument count is wrong or the key
 *         (parameter 1) is not a primitive type — Hive map keys must be primitive
 */
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
throws SemanticException {
if (parameters.length != 2) {
throw new UDFArgumentTypeException(parameters.length - 1,
"Exactly two arguments are expected.");
}
if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentTypeException(0,
"Only primitive type arguments are accepted but " +
parameters[0].getTypeName() + " was passed as parameter 1.");
}
return new GenericUDAFCollectMapEvaluator();
}

/** Per-group aggregation state: the map being accumulated. */
static class CollectMapAggregationBuffer implements AggregationBuffer {
Map<Object, Object> container;
}

public static class GenericUDAFCollectMapEvaluator extends GenericUDAFEvaluator {
// Inspector for the key argument (primitive only, enforced by the resolver).
private PrimitiveObjectInspector keyOI;
// Inspector for the value argument (any Hive type).
private ObjectInspector valueOI;
// Inspector for the partial result (a map), set in PARTIAL2/FINAL modes.
private StandardMapObjectInspector internalMergeOI;

/**
 * Sets up input/output object inspectors per aggregation mode.
 * PARTIAL1/COMPLETE receive the raw (key, value) arguments; PARTIAL2/FINAL
 * receive a single map parameter produced by terminatePartial().
 * The output in every mode is a standard map object inspector.
 */
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
throws HiveException {
super.init(m, parameters);
if (m == Mode.PARTIAL1) {
// Raw input -> partial aggregation.
keyOI = (PrimitiveObjectInspector) parameters[0];
valueOI = parameters[1];
return ObjectInspectorFactory
.getStandardMapObjectInspector(ObjectInspectorUtils
.getStandardObjectInspector(keyOI),
ObjectInspectorUtils
.getStandardObjectInspector(valueOI));
} else if (m == Mode.PARTIAL2 || m == Mode.FINAL) {
// Partial maps -> merged map; derive key/value OIs from the map OI.
internalMergeOI = (StandardMapObjectInspector) parameters[0];
keyOI = (PrimitiveObjectInspector) internalMergeOI.getMapKeyObjectInspector();
valueOI = internalMergeOI.getMapValueObjectInspector();
return ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
} else {
// COMPLETE: raw input -> final result in one step.
keyOI = (PrimitiveObjectInspector) ObjectInspectorUtils
.getStandardObjectInspector(parameters[0]);
valueOI = ObjectInspectorUtils.getStandardObjectInspector(parameters[1]);
return ObjectInspectorFactory
.getStandardMapObjectInspector(ObjectInspectorUtils
.getStandardObjectInspector(keyOI),
ObjectInspectorUtils
.getStandardObjectInspector(valueOI));
}
}

/** Clears the buffer for reuse by allocating a fresh map. */
@Override
public void reset(AggregationBuffer agg) throws HiveException {
((CollectMapAggregationBuffer) agg).container = new HashMap<Object, Object>();
}

@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
CollectMapAggregationBuffer ret = new CollectMapAggregationBuffer();
reset(ret);
return ret;
}

/**
 * Consumes one (key, value) row. Rows with a null key are skipped —
 * Hive maps do not support null keys.
 */
@Override
public void iterate(AggregationBuffer agg, Object[] parameters)
throws HiveException {
assert (parameters.length == 2);
Object k = parameters[0];
Object v = parameters[1];
if (k != null) {
CollectMapAggregationBuffer myagg = (CollectMapAggregationBuffer) agg;
put(k, v, myagg);
}
}

/** Emits a copy of the accumulated map as the partial result. */
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
CollectMapAggregationBuffer myagg = (CollectMapAggregationBuffer) agg;
HashMap<Object, Object> ret = new HashMap<Object, Object>(myagg.container.size());
ret.putAll(myagg.container);
return ret;
}

/**
 * Folds a partial map into the buffer. Duplicate keys from other partials
 * overwrite silently, which is where the indeterminate-value behavior
 * documented on the class comes from.
 */
@Override
public void merge(AggregationBuffer agg, Object partial)
throws HiveException {
if (partial == null) {
// Defensive: Hive may hand an empty/absent partial; nothing to merge.
return;
}
CollectMapAggregationBuffer myagg = (CollectMapAggregationBuffer) agg;
@SuppressWarnings("unchecked")
Map<Object, Object> partialResult = (Map<Object, Object>) internalMergeOI.getMap(partial);
for (Map.Entry<Object, Object> e : partialResult.entrySet()) {
put(e.getKey(), e.getValue(), myagg);
}
}

/** Emits a copy of the accumulated map as the final result. */
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
CollectMapAggregationBuffer myagg = (CollectMapAggregationBuffer) agg;
HashMap<Object, Object> ret = new HashMap<Object, Object>(myagg.container.size());
ret.putAll(myagg.container);
return ret;
}

/**
 * Deep-copies the key and value into standard Java objects (Hive may reuse
 * the underlying writables between rows) and stores them in the buffer.
 */
private void put(Object k, Object v, CollectMapAggregationBuffer myagg) {
Object kCopy = ObjectInspectorUtils.copyToStandardObject(k,
this.keyOI);
Object vCopy = ObjectInspectorUtils.copyToStandardObject(v,
this.valueOI);
myagg.container.put(kCopy, vCopy);
}
}
}
三、测试
测试数据如下:
class | struct |
1 | {"name":"N003","age":"20"} |
2 | {"name":"N001","age":"18"} |
1 | {"name":"N002","age":"19"} |
测试代码:
SELECT class, collect_map(class, struct_t) as res
FROM (
SELECT '1' as class, named_struct('name', 'N003', 'age', '20') as struct_t
union all
SELECT '2' as class, named_struct('name', 'N001', 'age', '18') as struct_t
union all
SELECT '1' as class, named_struct('name', 'N002', 'age', '19') as struct_t
) as test_data
group by class;
测试结果:
可以发现,具有重复key的这组(class=1),最终的value只有一个了。
为了解决这个问题,我们可以在加上一层collect_list处理。
SELECT class, collect_map(class, struct_array) as res
FROM (
SELECT class, collect_list(struct_t) as struct_array
FROM (
SELECT '1' as class, named_struct('name', 'N003', 'age', '20') as struct_t
union all
SELECT '2' as class, named_struct('name', 'N001', 'age', '18') as struct_t
union all
SELECT '1' as class, named_struct('name', 'N002', 'age', '19') as struct_t
) as test_data
group by class
) as tmp
group by class
;
四、参考文档
https://issues.apache.org/jira/secure/attachment/12620274/Collect_map.patch