大体流程
用python将生成的模型导出为pmml格式,java通过加载本地pmml文件的形式读取模型,给予参数,实现预测或分类,本文用spring boot,接口接收前端各参数数值,预测后返回给前端使用,大致这么个流程
python方面
生成模型并导出pmml的代码
def exportPmml():
iris = load_iris()
# 创建带有特征名称的 DataFrame
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
# 创建模型管道
iris_pipeline = PMMLPipeline([
("classifier", RandomForestClassifier())
])
# 训练模型
iris_pipeline.fit(iris_df, iris.target)
# 导出模型到 RandomForestClassifier_Iris.pmml 文件
sklearn2pmml(iris_pipeline, "RandomForestClassifier_Iris.pmml")
return None
网上可以搜到挺多,这里是用的JPMML导出Pmml,pipeline用的是PMMLPipeline,其实也是实现的sklearn的pipeline,如果用Nyoka,pipeline用的是sklearn的pipeline,差不多,它们的github主页都提供了Usage和多种模型案例
Java方面
一个坑
网上搜到的代码在加载pmml模型是如果遇到报错提示Exception in thread "main" java.lang.IllegalArgumentException: http://www.dmg.org/PMML-4_4
,尝试把文件里第二行版本号修改为
另一个小点
maven里有关pmml的包的版本号1.4.x与1.5.x的代码有差别,导的版本不匹配,可能导的东西都找不到,而且名字也会不一样
代码
@Component
public class LoadModel1_5_x {
private Evaluator evaluator;
public String modelUrl="E:\\Softwares\\Learn\\IntelliJ_Projects\\IrisModelTest\\src\\main\\resources\\static\\RandomForestClassifier_Iris.pmml";
public LoadModel1_5_x() {
PMML pmml;
// 模型导入
try {
if (modelUrl != null) {
File file = new File(modelUrl);
InputStream inputStream = new FileInputStream(file);
InputStream is = inputStream;
pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
//1.5.x版本 使用一个pmml配置了两个不同的evaluators
ModelEvaluatorBuilder modelEvaluatorBuilder = new ModelEvaluatorBuilder(pmml);
//第一种
this.evaluator = modelEvaluatorBuilder.build();
System.out.println("模型build成功!");
}
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (SAXException e) {
e.printStackTrace();
} catch (JAXBException e) {
e.printStackTrace();
}
}
public Object predict(Map<String,Double> irismap){
List<InputField> inputFields = evaluator.getInputFields();
// 过模型的原始特征,从画像中获取数据,作为模型输入
Map<FieldName, FieldValue> arguments = new LinkedHashMap();
for (InputField inputField : inputFields) {
FieldName inputFieldName = inputField.getName();
Object rawValue = irismap.get(inputFieldName.getValue());
FieldValue inputFieldValue = inputField.prepare(rawValue);
arguments.put(inputFieldName, inputFieldValue);
}
Map<FieldName, ?> results = evaluator.evaluate(arguments);
List<TargetField> targetFields = evaluator.getTargetFields();
//对于分类问题等有多个输出。
for (TargetField targetField : targetFields) {
FieldName targetFieldName = targetField.getName();
Object targetFieldValue = results.get(targetFieldName);
System.err.println("target: " + targetFieldName.getValue()
+ " value: " + targetFieldValue);
}
return results;
}
}
然后controller里autowired一个模型,构造一个参数map
Map<String, Double> map = new HashMap<String, Double>();
map.put("sepal length (cm)", (double) 7);
map.put("sepal width (cm)", (double) 3.2);
map.put("petal length (cm)", (double) 4.7);
map.put("petal width (cm)", (double) 1.4);
使用模型类的predict方法就可以简单实现啦