背景
最近在使用阿里的机器学习pai进行模型训练,训练出来的模型需要给java进行调用来使用。本博客阐述java调用pmml进行预测的过程。
实战
获取pmml模型
在机器学习pai->模板实验->心脏病预测,可以直接训练并且导出模型,具体步骤参考:https://help.aliyun.com/document_detail/34929.html?spm=a2c4g.11186623.6.676.5394607a96UDkz
java调用pmml
依赖支持库jpmml:https://github.com/jpmml/jpmml-evaluator.git
<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator -->
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.3.5</version>
</dependency>
这里使用的是1.3.5,跟阿里的版本保持一致,避免出现版本兼容性的问题。
下载pmml的模型文件:https://download.csdn.net/download/wangjie5540/12101924
java源码
@Log4j2
public class PmmlPredict {
public static Evaluator evaluator;
public static void initModel() throws IOException, SAXException, JAXBException {
File file = new File("/home/wang/Downloads/lr_demo.pmml");
PMML pmml = null;
try (InputStream is = new FileInputStream(file)) {
pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
}
evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);
evaluator.verify();
}
public static Integer predict(JSONObject feature) {
// 获取模型定义的特征
List<? extends InputField> inputFields = evaluator.getInputFields();
log.info("模型的特征是:{}", inputFields);
// 获取模型定义的目标名称
List<? extends TargetField> targetFields = evaluator.getTargetFields();
log.info("目标字段是:{}", targetFields);
Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
for (InputField inputField : inputFields) {
FieldName inputName = inputField.getName();
String name = inputName.getValue();
Object rawValue = feature.get(name);
FieldValue inputValue = inputField.prepare(rawValue);
arguments.put(inputName, inputValue);
}
Map<FieldName, ?> results = evaluator.evaluate(arguments);
for (Map.Entry<FieldName, ?> entry : results.entrySet()) {
log.info(entry.getKey());
log.info(entry.getValue());
}
Map resultRecord = (Map) EvaluatorUtil.decode(results);
Object res = resultRecord.get(FieldName.create("y"));
log.info("预测结果:");
log.info(results);
log.info(resultRecord);
log.info(res);
log.info(evaluator.getActiveFields());
log.info(evaluator.getTargetFields());
log.info(evaluator.getSummary());
log.info(evaluator.getMiningFunction());
return null;
}
public static void main(String[] args) throws JAXBException, SAXException, IOException {
PmmlPredict.initModel();
JSONObject jsonObject = new JSONObject();
jsonObject.put("sex", 1);
jsonObject.put("address", 1);
jsonObject.put("famsize", 1);
jsonObject.put("pstatus", 0);
jsonObject.put("medu", 0);
jsonObject.put("fedu", 1);
jsonObject.put("mjob", 1);
jsonObject.put("fjob", 0);
jsonObject.put("guardian", 1);
jsonObject.put("traveltime", 0);
jsonObject.put("studytime", 0.3);
jsonObject.put("failures", 0.3);
jsonObject.put("schoolsup", 0);
jsonObject.put("fumsup", 1);
jsonObject.put("paid", 0);
jsonObject.put("activities", 0);
jsonObject.put("higher", 0);
jsonObject.put("internet", 1);
jsonObject.put("famrel", 0);
jsonObject.put("freetime", 0.7);
jsonObject.put("goout", 0.5);
jsonObject.put("dalc", 0.7);
jsonObject.put("walc", 0);
jsonObject.put("health", 0.5);
jsonObject.put("absences", 0.8);
// 1,1,0,0,1,1,0,1,0,0.3333333333333333,0.3333333333333333,0,1,0,0,0,1,0,0.75,0.5,0.75,0,0,0.5,0.08,0
PmmlPredict.predict(jsonObject);
}
}