一、将pmml模型放在对应目录下
二、编写测试类
1.引入依赖包
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.4.5</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.4.5</version>
</dependency>
2.编写测试类
package com.idata.web.controller;
import java.io.FileInputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorBuilder;
import org.jpmml.evaluator.TargetField;
import org.jpmml.model.PMMLUtil;
import org.springframework.util.StringUtils;
public class UsePmml {
private static ModelEvaluator<?> evaluate;
private static PMML pmml;
private static FileInputStream inputStream;
static {
//获取EMQ连接信息
try {
inputStream = new FileInputStream("D:\\wincheer\\python_basic\\电极法校准\\rfr.pmml"); // 已训练好的RandomForestRegressor模型
pmml = PMMLUtil.unmarshal(inputStream);
evaluate = new ModelEvaluatorBuilder(pmml).build();
} catch (Exception e) {
e.printStackTrace();
}
}
public static void main(String[] args) throws Exception {
// 导入模型。 模型初始化时间较久,最好在Web启动的时候实例化evaluate
// 构建输入参数
Map<String, Object> paramData = new HashMap<>();
paramData.put("x1", 39.35); // 监测点1-COD
paramData.put("x2", 6.22); // 监测点1-氨氮
paramData.put("x3", 7.06); // 监测点1-PH
paramData.put("x4", 24.2); // 监测点1-水温
paramData.put("x5", 38.43); // 监测点2-COD
paramData.put("x6", 0.54); // 监测点2-氨氮
paramData.put("x7", 7.66); // 监测点2-PH
paramData.put("x8", 30); // 监测点2-水温
paramData.put("x9", 14.94); // 监测点3-化学-COD
paramData.put("x10", 0.01); // 监测点3-化学-氨氮
paramData.put("x11", 0.2); // 监测点3-化学-总磷
paramData.put("x12", 7.31); // 监测点3-PH
paramData.put("x13", 303.70001); // 监测点3-电导率
paramData.put("x14", 98.27); // 监测点3-浊度
paramData.put("x15", 25.6); // 监测点3-水温
paramData.put("x16", 329504.0938); // 监测点3-流量
paramData.put("x17", 0.01); // 监测点3-液位
paramData.put("x18", 20.7); // 甲烷
paramData.put("x19", 0); // 氧气
paramData.put("x20", 0); // 氨气
paramData.put("x21", 0); // 一氧化碳
paramData.put("x22", 0); // 硫化氢
paramData.put("x23", 0); // 二氧化硫
Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
List<InputField> inputFields = evaluate.getInputFields(); // 模型参数
for (InputField inputField : inputFields) {
// 将参数通过模型对应的名称进行添加
FieldName inputFieldName = inputField.getName(); // 获取模型中的参数名
Object paramValue = paramData.get(inputFieldName.getValue()); // 获取模型参数名对应的参数值
FieldValue fieldValue = inputField.prepare(paramValue); // 将参数值填入模型中的参数中
arguments.put(inputFieldName, fieldValue); // 存放在map列表中
}
// 开始评估/预测
Map<FieldName, ?> target = evaluate.evaluate(arguments);
// 获取评估/预测结果
List<TargetField> targetFields = evaluate.getTargetFields();
Object targetFieldValue = target.get(targetFields.get(0).getName());
String tempVal=targetFieldValue.toString();
Object correctingVal="";
if(tempVal!=null && tempVal.indexOf("=")!=-1){
String temp= tempVal.split("=")[1];
correctingVal=temp.substring(0,temp.length()-1);
}
System.out.println("targetFieldValue: " + ToFloat(correctingVal));
}
/**
* 转float
* @param val
* @return
*/
public static float ToFloat(Object val){
if(StringUtils.isEmpty(val)){
return 0;
}else{
float v=0;
v=Float.parseFloat(val.toString());
return v;
}
}
}