1. Requirement prototype
The requirement: when computing lx for subsequent rows, each row needs the previous row's lx and qx. Because the previous lx is itself a computed value rather than an input column, none of Spark's built-in functions fit, and lag does not work either, so a custom UDAF is needed.
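To make the recurrence concrete, here is a minimal stand-alone sketch. This is my own illustration, inferred from the iterate() logic further below, and the qx values are made-up sample data: lx is taken as 1 in policy year 1, and each later year's lx is the previous lx times (1 minus the previous qx).

import java.math.BigDecimal;

//Stand-alone illustration of the assumed recurrence: lx(1) = 1, lx(n) = lx(n-1) * (1 - qx(n-1))
public class LxRecurrenceDemo {
    public static void main(String[] args) {
        //made-up sample qx values for three policy years
        BigDecimal[] qx = {new BigDecimal("0.001"), new BigDecimal("0.002"), new BigDecimal("0.003")};
        BigDecimal lx = BigDecimal.ONE; //lx for policy year 1
        for (int year = 1; year <= qx.length; year++) {
            System.out.println("policy_year=" + year + " lx=" + lx.toPlainString());
            //the next year's lx needs this row's lx and qx, hence the need for the previous row
            lx = lx.multiply(BigDecimal.ONE.subtract(qx[year - 1]));
        }
    }
}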
Required dependency
<dependency>
<groupId>org.apache.hive</groupId>
<artifactId>hive-exec</artifactId>
<version>3.1.3</version>
</dependency>
The first class extends AbstractGenericUDAFResolver; its main job is to hand back the GenericUDAFEvaluator implementation (which in turn creates the
GenericUDAFEvaluator.AbstractAggregationBuffer).
package fun;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import fun.lxFun.LxFunUDAFEvaluator;
public class FieldLength extends AbstractGenericUDAFResolver {
//both getEvaluator overloads hand back the evaluator that holds the actual aggregation logic
@Override
public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) {
return new LxFunUDAFEvaluator();
}
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) {
return new LxFunUDAFEvaluator();
}
}
The second class extends GenericUDAFEvaluator.AbstractAggregationBuffer; it holds the per-group intermediate state.
package fun.lxFun;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.util.JavaDataModel;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import java.util.ArrayList;
import java.util.List;
public class LxFunAggregationBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer {
//per-group accumulator: each merge() call appends the current lx value here
List<HiveDecimalWritable> list = new ArrayList<>();
/**
* Estimated size of the aggregation buffer in bytes.
*
* @return
*/
@Override
public int estimate() {
return JavaDataModel.JAVA64_ARRAY_META;
}
}
The third class extends GenericUDAFEvaluator; this is where the actual aggregation logic lives.
package fun.lxFun;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import java.io.FileWriter;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
/**
* @Description: the actual processing class of the UDAF
* @author: willzhao E-mail: zq2599@gmail.com
* @date: 2020/11/4 9:57
*/
public class LxFunUDAFEvaluator extends GenericUDAFEvaluator {
//value returned for each row
private HiveDecimalWritable result;
//set to false once merge() has received a non-null partial value
boolean empty;
//qx cache, keyed by age + sex + policy_year
private HashMap<String, BigDecimal> map = new HashMap<>();
//running lx value, keyed by age + sex + policy_year
private HashMap<String, BigDecimal> tmp_lx = new HashMap<>();
/**
* Called in every phase.
* Mainly sets up the input/output ObjectInspectors needed by this phase,
* so the other methods can use them directly when they are called.
*
* @param m
* @param parameters
* @return
* @throws HiveException
*/
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
super.init(m, parameters);
//initialize the returned value
result = new HiveDecimalWritable(0);
//for the next phase: declare the type of this phase's output
return PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector;
}
@Override
public AggregationBuffer getNewAggregationBuffer() {
//create a fresh LxFunAggregationBuffer for each group and reset it
LxFunAggregationBuffer result = new LxFunAggregationBuffer();
reset(result);
return result;
}
/**
* Reset: clear the accumulated state for the group.
*
* @param agg
*/
@Override
public void reset(AggregationBuffer agg) {
//clear the per-group list; the log below only records that reset ran
((LxFunAggregationBuffer) agg).list.clear();
try {
FileWriter fileWriter = new FileWriter("/root/log.txt", true);
fileWriter.write("reset:" + "\n");
fileWriter.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Called once per input row; all accumulated data ends up in agg.
*
* @param agg
* @param parameters
*/
@Override
public void iterate(AggregationBuffer agg, Object[] parameters) {
//core business logic
//parameter 1: qx of the current row
Object oneQx = parameters[0];
//parameter 2: policy year
Object twoPy = parameters[1];
//parameter 3: insured age
Object threeAge = parameters[2];
//parameter 4: sex
Object four = parameters[3];
try {
FileWriter fileWriter = new FileWriter("/root/log.txt", true);
fileWriter.write("oneQx:" + oneQx + "\ntwoPy:" + twoPy + "\nthreeAge:" + threeAge + "\nfour:" + four.toString() + "\n");
fileWriter.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
BigDecimal qx = new BigDecimal(oneQx.toString());
int py = Integer.parseInt(twoPy.toString());
int age = Integer.parseInt(threeAge.toString());
String sex = four.toString();
if (py == 1) {
//first policy year: lx starts at 1 and the current qx is cached for the next year
map.put(age + sex + py, qx);
tmp_lx.put(age + sex + py, new BigDecimal(py));
} else {
//later years: lx(py) = lx(py-1) * (1 - qx(py-1)), computed only once per key
if (tmp_lx.get(age + sex + (py - 1)) != null && map.get(age + sex + (py - 1)) != null) {
if (tmp_lx.get(age + sex + (py)) == null && map.get(age + sex + (py)) == null) {
tmp_lx.put(age + sex + py, tmp_lx.get(age + sex + (py - 1)).multiply(BigDecimal.ONE.subtract(map.get(age + sex + (py - 1)))));
map.put(age + sex + py, qx);
}
}
}
try {
FileWriter fileWriter = new FileWriter("/root/log.txt", true);
fileWriter.write("tmp_lx:" + tmp_lx.get(age + sex + (py)) + "\n");
fileWriter.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
//pass the freshly computed lx for this row straight to merge(), which stores it in agg
merge(agg, tmp_lx.get(age + sex + (py)));
}
/**
* Returns the final result of the current group (group by / window frame).
*
* @param agg
* @return
*/
@Override
public Object terminate(AggregationBuffer agg) {
if (empty) {
return null;
}
//walk the buffered values and copy them into result; the last value wins and is returned for this row
if (!((LxFunAggregationBuffer) agg).list.isEmpty()) {
((LxFunAggregationBuffer) agg).list.forEach(e -> {
if (e.getHiveDecimal() != null) {
if (e.getHiveDecimal().isInt()) {
result.set(e.getHiveDecimal());
} else {
result.set(e.getHiveDecimal(), 5, 10);
}
}
});
}
try {
FileWriter fileWriter = new FileWriter("/root/log.txt", true);
fileWriter.write("方法terminate:" + result +"\n");
fileWriter.flush();
fileWriter.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
return result;
}
/**
* Runs when the current phase ends and returns the partial aggregation result (map/combiner side).
* Invocation: with "order by t5.ins_age, t5.policy_year, t5.sex", once the sort columns of a group
* have been processed, that group's data is returned automatically.
*
* @param agg
* @return
*/
@Override
public Object terminatePartial(AggregationBuffer agg) {
return terminate(agg);
}
/**
* Merge a partial value into the aggregation buffer (combiner or reduce side).
*
* @param agg
* @param partial
*/
@Override
public void merge(AggregationBuffer agg, Object partial) {
//merge the value into agg so it can be handed back to the caller
if (partial != null) {
empty = false;
try {
FileWriter fileWriter = new FileWriter("/root/log.txt", true);
fileWriter.write("方法merge:" + partial + "\n");
fileWriter.flush();
fileWriter.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
//store the computed value in agg
((LxFunAggregationBuffer) agg).list.add(new HiveDecimalWritable(partial.toString()));
}
}
}
The code above is commented throughout.
A quick note on the UDAF workflow (a call-order sketch follows the list):
1. init runs first and defines the input and output/return types.
2. reset initializes the per-group data.
3. iterate runs the core business logic.
4. Inside iterate, merge combines the data and puts it into agg.
5. terminate returns the value for each row.
Steps 3 to 5 are repeated once for every row of business data.
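As a rough illustration of that order, here is a minimal sketch. It is not Hive's real driver code; it only mimics how the evaluator is driven for one partition, and the inputInspectors and rows arguments are hypothetical placeholders for what Hive actually supplies.

package fun.lxFun;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;

public class UdafCallOrderSketch {
    //drives one partition's rows through the evaluator in the order listed above
    public static Object runOnePartition(ObjectInspector[] inputInspectors, Object[][] rows) throws HiveException {
        GenericUDAFEvaluator eval = new LxFunUDAFEvaluator();
        //step 1: init declares the input/output types
        eval.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors);
        //step 2: a fresh buffer; getNewAggregationBuffer() calls reset() itself
        GenericUDAFEvaluator.AggregationBuffer buf = eval.getNewAggregationBuffer();
        for (Object[] row : rows) {
            //steps 3-4: iterate() runs the business logic and calls merge() internally
            eval.iterate(buf, row);
        }
        //step 5: the final value for the group
        return eval.terminate(buf);
    }
}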
SQL
add jar hdfs://192.168.88.161:8020/jar/original-sparkFunction-1.0-SNAPSHOT.jar;
create temporary function LxFun as 'fun.FieldLength';
select LxFun(t4.qx, t4.policy_year, t4.t_age, t4.sex) over (partition by t4.t_age order by t4.policy_year) as lx,
       t4.*
from t4;
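Since the requirement is to run this from Spark, here is a minimal, hedged sketch of submitting the same statements through a Hive-enabled SparkSession. The class and app names are my own; the jar path, function name, and query are the ones above, with the UDAF given the four arguments that iterate() expects (qx, policy_year, age, sex).

import org.apache.spark.sql.SparkSession;

public class LxFunSparkDemo {
    public static void main(String[] args) {
        //Hive support is required so that "add jar" and the Hive UDAF registration work
        SparkSession spark = SparkSession.builder()
                .appName("LxFunDemo")
                .enableHiveSupport()
                .getOrCreate();
        spark.sql("add jar hdfs://192.168.88.161:8020/jar/original-sparkFunction-1.0-SNAPSHOT.jar");
        spark.sql("create temporary function LxFun as 'fun.FieldLength'");
        spark.sql("select LxFun(t4.qx, t4.policy_year, t4.t_age, t4.sex) "
                + "over (partition by t4.t_age order by t4.policy_year) as lx, t4.* "
                + "from t4").show();
        spark.stop();
    }
}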
Results
Result comparison: in terms of precision, nothing is lost compared with the reference sample, and the results match exactly.
The case succeeded.
In practice the run is fairly involved; feel free to contact me if you run into problems~