概要
最近接了一个项目,项目里面有个需求,需要对接第三方的python模型,需求就是在系统中进行问卷的填写,得到对应的分数组合,然后把分数组合调用对方的python模型得到相应的预测值。第三方给了一段python调用lib结尾的模型文件的示例代码,刚开始想着把这段代码通过AI工具转成java试试,后来发现行不通,因为对应的引用库啥的都没有;后来又想了一下,搭一个python环境,把代码部署到服务器上,java去调。后来发现也行不通。然后找第三方沟通了一下,说是可以把对应的模型转成pmml文件,然后用java调用去调用pmml文件也能实现。跟着这个思路,最终实现了需求
实现步骤一:pmml文件放到项目中
废话不多说,开弄:
对方给的pmml文件:
在项目的webapp目录下建立一下pmml目录,把文件放在项目的该目录下(也可以放在其他目录,java代码读取文件的时候路径对应上就行)
实现步骤二:引入对应的maven依赖
需要看一下对应的pmml文件版本号是多少,随便点一个文件看看即可
我这边是4.4版本的,引入对应依赖
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.5.11</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.5.11</version>
</dependency>
如果是4.4以下的版本,4.3版本这种可以引入以下maven试试:
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.4.1</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.4.1</version>
</dependency>
实现步骤三:java实现调用
import lombok.extern.slf4j.Slf4j;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.jpmml.model.PMMLUtil;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@Slf4j
public class MoxibustionService {
private static Map<String, Evaluator> evaluatorMap = new HashMap<>();
private static Evaluator initModelEvaluator(){
Evaluator modelEvaluator = null;
// 检查是否已经存在该key对应的对象
if (evaluatorMap.containsKey(modelName)) {
//存在直接map里面取,无需加载,不然模型每次加载比较耗时
modelEvaluator = evaluatorMap.get(modelName);
}
if (Objects.isNull(modelEvaluator)) {
PMML pmml;
try {
log.info("init:" + modelName);
//根据文件的目录 获取文件
Resource resource = new ClassPathResource("pmml/" + modelName + ".pmml");
InputStream is = resource.getInputStream();
pmml = PMMLUtil.unmarshal(is);
try {
is.close();
} catch (IOException e) {
log.info("InputStream close error!");
}
ModelEvaluatorBuilder modelEvaluatorBuilder = new ModelEvaluatorBuilder(pmml);
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
modelEvaluatorBuilder.setModelEvaluatorFactory(modelEvaluatorFactory);
modelEvaluator = modelEvaluatorBuilder.build();
modelEvaluator.verify();
log.info("加载" + modelName + "模型成功!");
//加载完毕存储到map,下次使用的时候不用在初始化 提升效率
evaluatorMap.put(modelName, modelEvaluator);
} catch (Exception e) {
log.info("加载" + modelName + "模型失败!", e.getMessage());
throw new XKSHException("加载" + modelName + "模型失败!");
}
}
return modelEvaluator;
}
/**
* 获取目标字段名称
*
* @return
*/
public static String getTargetName () {
return modelEvaluator.getTargetFields().get(0).getName().toString();
}
/**
* 使用模型生成概率分布
*
* @param inputFeature
* @return
*/
private static ProbabilityDistribution getProbabilityDistribution (Integer[] inputFeature,String modelName) {
//加载模型文件
Evaluator modelEvaluator = initModelEvaluator(modelName);
Map<FieldName, Number> paramMap = new HashMap<>();
List<InputField> inputFields = modelEvaluator.getInputFields();
for (int i = 0; i < inputFields.size(); i++) {
InputField inputField = inputFields.get(i);
//赋值
paramMap.put(inputField.getName(), inputFeature[i]);
}
Map<FieldName, ?> evaluateResult = modelEvaluator.evaluate(paramMap);
FieldName fieldName = FieldName.create(getTargetName());
return (ProbabilityDistribution) evaluateResult.get(fieldName);
}
/**
* 预测
*
* @param inputFeature
* @return
*/
public static ValueMap<String, Double> predictProba (Integer[] inputFeature,String modelName) {
ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(inputFeature,modelName);
return probabilityDistribution.getValues();
}
public static void main(String[] args) throws Exception {
Integer[] inputFeature = {1, 0, 3, 1, 1, 2, 2, 2, 3, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0};
String modelName = "身体体质";
ValueMap<String, Double> values = MoxibustionService.predictProba(inputFeature,modelName);
System.out.println("计算结果:" + values);
}
}
运行
注意事项
在开发过程中,刚开始的时候老是提示获取不到对应的文件,提示项目中文件不存在,我就先去项目的target编译文件classess中查看了一下,发现编译文件中确实没有pmml目录,那就基本就是项目编译问题,然后去pom文件配置中,找了一下,发现配置了项目bulid的时候只处理了xml格式的文件:
然后在下面加上pmml格式的文件,重新编译了一下,发现target编译文件classess中有了pmml文件,再执行代码也执行成功了