一、实现展示
hive> desc test_avg_str_in_str;
user_id int
name string
value int
hive> select * from test_avg_str_in_str;
1 awuz 1
1 azhaoz 1
2 zhangsan 2
2 lisi 2
2 wangwu 3
-- UDAF: avgStr (找到name中出现z的次数,再求平均数)
-- 难点在计算平均数的时候,中间结果需要保存 总值和计数值,需要用到 LazyBinaryStruct 结构
hive> select user_id, avgStr(name, "z") from test_avg_str_in_str group by user_id;
1 1.5
2 0.333333
PS. 这个UDAF实现的功能目前自己瞎想的,没有啥业务应用…
二、关键函数
- PARTIAL1: map阶段, 调用iterate()和terminatePartial()
- PARTIAL2: map端的Combiner阶段,调用merge() 和 terminatePartial()
- FINAL: reduce阶段,调用merge()和terminate()
// 确定各个阶段输入输出参数的数据格式ObjectInspectors
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException;
// 保存数据聚集结果的类
abstract AggregationBuffer getNewAggregationBuffer() throws HiveException;
// 重置聚集结果
public void reset(AggregationBuffer agg) throws HiveException;
// map阶段,迭代处理输入sql传过来的列数据
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException;
// map与combiner结束返回结果,得到部分数据聚集结果
public Object terminatePartial(AggregationBuffer agg) throws HiveException;
// combiner合并map返回的结果,还有reducer合并mapper或combiner返回的结果。
public void merge(AggregationBuffer agg, Object partial) throws HiveException;
// reducer阶段,输出最终结果
public Object terminate(AggregationBuffer agg) throws HiveException;
三、代码CODE
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAF