collect_set无法满足业务需要:它只排重、不排序。为了实现既排重又排序,重写了collect_set的底层源码。
其实就是把底层的LinkedHashSet改成TreeSet。
涉及到的类
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet;
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator;
改写的是
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator;
改写的内容是
public MkArrayAggregationBuffer() {
if (bufferType == BufferType.LIST){
container = new ArrayList<Object>();
} else if(bufferType == BufferType.SET){
//container = new LinkedHashSet<Object>();
container = new TreeSet<Object>(); //由原来的LinkedHashSet改写成TreeSet
} else {
throw new RuntimeException("Buffer type unknown");
}
}
实现的效果
新建测试表
create table t_test(id string,name string);
insert into t_test select '1','a';
insert into t_test select '1','b';
insert into t_test select '1','a';
insert into t_test select 'b','d';
insert into t_test select '2','d';
insert into t_test select '2','e';
insert into t_test select '2','a';
insert into t_test select '1','h';
--用原有collect_set查询效果
select id,collect_set(name) from t_test group by id;
2 ["d","e","a"]
b ["d"]
1 ["a","b","h"]
--用重写后的my_collect_set查询效果
select id,my_collect_set(name) from t_test group by id;
2 ["a","d","e"]
b ["d"]
1 ["a","b","h"]
完整代码如下
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet
package com.hxy.udaf;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import com.hxy.udaf.GenericUDAFMkCollectionEvaluator.BufferType;
/**
* GenericUDAFCollectSet
*/
@Description(name = "collect_set", value = "_FUNC_(x) - Returns a set of objects with duplicate elements eliminated")
/**
 * Resolver for a sorted variant of Hive's {@code collect_set} UDAF.
 *
 * <p>Unlike the built-in {@code collect_set} (which preserves arrival order via a
 * {@code LinkedHashSet}), the evaluator returned here accumulates values in a
 * {@code TreeSet}, so the resulting array is both de-duplicated and sorted in the
 * elements' natural order.
 *
 * <p>NOTE(review): STRUCT/MAP/LIST arguments are still accepted below for parity with
 * the original resolver, but the standard-object copies of such values do not implement
 * {@code Comparable} and will fail at runtime inside the {@code TreeSet}. In practice
 * only PRIMITIVE arguments are safe with this sorted variant.
 */
@Description(name = "my_collect_set",
    value = "_FUNC_(x) - Returns a sorted set of objects with duplicate elements eliminated")
public class GenericUDAFCollectSet extends AbstractGenericUDAFResolver {

  public GenericUDAFCollectSet() {
  }

  /**
   * Validates the single argument and returns the SET-mode evaluator.
   *
   * @param parameters type info of the call-site arguments; exactly one is expected
   * @return an evaluator that collects values into a sorted, de-duplicated set
   * @throws SemanticException if the argument count or category is unsupported
   */
  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
      throws SemanticException {
    if (parameters.length != 1) {
      throw new UDFArgumentTypeException(parameters.length - 1,
          "Exactly one argument is expected.");
    }
    switch (parameters[0].getCategory()) {
      case PRIMITIVE:
      case STRUCT:
      case MAP:
      case LIST:
        break;
      default:
        throw new UDFArgumentTypeException(0,
            "Only primitive, struct, list or map type arguments are accepted but "
                + parameters[0].getTypeName() + " was passed as parameter 1.");
    }
    return new GenericUDAFMkCollectionEvaluator(BufferType.SET);
  }
}
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
package com.hxy.udaf;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.TreeSet;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
/**
 * Evaluator backing the sorted {@code my_collect_set} UDAF.
 *
 * <p>This is Hive's stock {@code GenericUDAFMkCollectionEvaluator} with one behavioral
 * change: in SET mode the aggregation buffer is a {@link TreeSet} instead of a
 * {@code LinkedHashSet}, so {@link #terminate} returns elements de-duplicated AND
 * sorted in natural order.
 *
 * <p>NOTE(review): {@code TreeSet} requires its elements to implement
 * {@code Comparable}; non-primitive inputs (struct/map/list copies) will throw
 * {@code ClassCastException} at insertion time.
 */
public class GenericUDAFMkCollectionEvaluator extends GenericUDAFEvaluator
    implements Serializable {

  private static final long serialVersionUID = 1L; // was "1l" — lowercase l reads as digit 1

  enum BufferType { SET, LIST }

  // For PARTIAL1 and COMPLETE: ObjectInspector for the original (per-row) data.
  private transient ObjectInspector inputOI;
  // For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations (list of objects).
  private transient StandardListObjectInspector loi;
  private transient ListObjectInspector internalMergeOI;

  private BufferType bufferType;

  // No-arg constructor required by Kryo deserialization.
  public GenericUDAFMkCollectionEvaluator() {
  }

  public GenericUDAFMkCollectionEvaluator(BufferType bufferType) {
    this.bufferType = bufferType;
  }

  /**
   * Initializes input/output object inspectors for the given mode.
   * The output of any partial aggregation (and the final result) is a list.
   */
  @Override
  public ObjectInspector init(Mode m, ObjectInspector[] parameters)
      throws HiveException {
    super.init(m, parameters);
    if (m == Mode.PARTIAL1) {
      inputOI = parameters[0];
      return ObjectInspectorFactory.getStandardListObjectInspector(
          ObjectInspectorUtils.getStandardObjectInspector(inputOI));
    } else {
      if (!(parameters[0] instanceof ListObjectInspector)) {
        // COMPLETE mode without map-side aggregation: input is still row data.
        inputOI = ObjectInspectorUtils.getStandardObjectInspector(parameters[0]);
        return ObjectInspectorFactory.getStandardListObjectInspector(inputOI);
      } else {
        // PARTIAL2/FINAL: input is the list produced by terminatePartial.
        internalMergeOI = (ListObjectInspector) parameters[0];
        inputOI = internalMergeOI.getListElementObjectInspector();
        loi = (StandardListObjectInspector)
            ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
        return loi;
      }
    }
  }

  /** Aggregation buffer: a list (LIST mode) or a sorted set (SET mode). */
  class MkArrayAggregationBuffer extends AbstractAggregationBuffer {

    private Collection<Object> container;

    public MkArrayAggregationBuffer() {
      if (bufferType == BufferType.LIST) {
        container = new ArrayList<Object>();
      } else if (bufferType == BufferType.SET) {
        // Changed from LinkedHashSet: TreeSet keeps elements de-duplicated
        // AND sorted in natural order. Elements must be Comparable.
        container = new TreeSet<Object>();
      } else {
        throw new RuntimeException("Buffer type unknown");
      }
    }
  }

  @Override
  public void reset(AggregationBuffer agg) throws HiveException {
    ((MkArrayAggregationBuffer) agg).container.clear();
  }

  @Override
  public AggregationBuffer getNewAggregationBuffer() throws HiveException {
    return new MkArrayAggregationBuffer();
  }

  /** Map side: fold one row's value into the buffer (nulls are skipped). */
  @Override
  public void iterate(AggregationBuffer agg, Object[] parameters)
      throws HiveException {
    assert (parameters.length == 1);
    Object p = parameters[0];
    if (p != null) {
      MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
      putIntoCollection(p, myagg);
    }
  }

  /** Map side: emit the buffer contents as a plain list for shuffling. */
  @Override
  public Object terminatePartial(AggregationBuffer agg) throws HiveException {
    MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
    List<Object> ret = new ArrayList<Object>(myagg.container.size());
    ret.addAll(myagg.container);
    return ret;
  }

  /** Reduce side: fold a partial list into the buffer. */
  @Override
  public void merge(AggregationBuffer agg, Object partial)
      throws HiveException {
    MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
    // getList() only guarantees a List; the previous ArrayList cast could
    // fail for other List implementations.
    List<Object> partialResult = (List<Object>) internalMergeOI.getList(partial);
    if (partialResult != null) {
      for (Object i : partialResult) {
        putIntoCollection(i, myagg);
      }
    }
  }

  /** Final result: the buffer contents as a list (sorted in SET mode). */
  @Override
  public Object terminate(AggregationBuffer agg) throws HiveException {
    MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
    List<Object> ret = new ArrayList<Object>(myagg.container.size());
    ret.addAll(myagg.container);
    return ret;
  }

  /** Deep-copies the value to a standard object so it survives row-object reuse. */
  private void putIntoCollection(Object p, MkArrayAggregationBuffer myagg) {
    Object pCopy = ObjectInspectorUtils.copyToStandardObject(p, this.inputOI);
    myagg.container.add(pCopy);
  }

  public BufferType getBufferType() {
    return bufferType;
  }

  public void setBufferType(BufferType bufferType) {
    this.bufferType = bufferType;
  }
}
1641

被折叠的 条评论
为什么被折叠?



