前言
这一章来学习如何使用 java 调用 python 机器学习模块,毕竟 python 在算法方法好用,但是做 web 项目还是 java 更优,最近有个项目想要集成机器学习算法,这里简单记录一下(默认使用idea开发工具,默认都会创建maven项目,不会自行百度)。
java 调用 python,分三步来学习:
- 第一步:java 调用 python 语句
- 第二步:java 调用 python 脚本
- 第三步:java调用python脚本函数(如何传递参数)
- 第四步:java调用python机器学习模块并运行
上面三步都需要调用 jython 库,两种加入项目方法:
- 从官网下载 jar 包,手动加入
- 在 pom 文件中直接加入依赖
<dependency>
<groupId>org.python</groupId>
<artifactId>jython-standalone</artifactId>
<version>2.7.0</version>
</dependency>
1、java 调用 python 语句
首先要在idea中导入jython库,上面提到的两种方法,我使用的是第二种,即在 pom 文件中加入依赖,简单明了。
创建一个 javaRunPython 类,执行下列代码:
PythonInterpreter interpreter = new PythonInterpreter();
interpreter.exec("a=1+2; ");
interpreter.exec("print(a);");
运行结果如下:
另外,我发现一个有趣的现象,无论是"print a",还是"print(a)",居然都没有报错,也就是说 jython 兼容支持 python2 和 python3 两种语法。
2、java 调用 python 脚本
2.1 PythonInterpreter 调用 python 脚本
首先要写一个 python 脚本用来被调用,内容随意,下面是我写的:
hello = 'hello world, this is using java to pring python word'
print(hello)
然后在 maven 项目中,运行:
import org.python.util.PythonInterpreter;
PythonInterpreter interpreter = new PythonInterpreter();
interpreter.execfile("E:\\pythonTest.py");
运行结果如下:
上面调用的只是普通的 python 脚本,如果脚本中导入了第三方库,还能不能运行呢?测试一下,写一个简单的生成矩阵 python 脚本:
print("sdafd")
import numpy as np
n = np.arange(0, 30, 2)
n = n.reshape(3, 5)
print(n)
print("dafdafda")
运行结果:
由上图可以看到,只执行了第一句,当导入第三方库的时候报错了,这个是因为在 jython 库中不存在 numpy 模块,自然会报错。
由此可以得出:PythonInterpreter 可以简单执行普通的 python 脚本,但是对于带有第三方库的 python 脚本就不行了。
2.2 使用 Runtime 调用 python 脚本(推荐)
先来个普通脚本:
hello = 'hello world, this is using java to pring python word'
print(hello)
文件名为 Runtime.py
Runtime 方法的 java 代码如下:
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
Process proc;
try {
proc = Runtime.getRuntime().exec("python E:\\Runtime.py");
BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream()));
String line = null;
while ((line = in.readLine()) != null) {
System.out.println(line);
}
in.close();
proc.waitFor();
} catch (IOException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
执行结果和上面一样
再来看看带有第三方库的脚本:
print("sdafd")
import numpy as np
n = np.arange(0, 30, 2)
n = n.reshape(3, 5)
print(n)
print("dafdafda")
因为文件名和位置没有变化,Runtime 的代码不变,再次执行,结果如下:
可以看到,完美执行,没有报错。
这里说一下为什么没有报错:
Runtime 执行本质类似于控制台调用 python.exe 去执行 python 脚本,也就是调用的 python 线程,python 仓库中本身是有 numpy 模块的,自然就不会报错了;
proc = Runtime.getRuntime().exec("python E:\\Runtime.py");
上方的代码中,可以看到 python 字样,这个是调用 python 程序的,只是 python 配置了环境变量,否则就要换成 D:\Python3.8.6\python.exe 了。
该方法有一个缺点,它得到 python 执行结果是通过数据流得到的,每次读取一行,相当于每执行一次就要读取一次结果,这就会导致我们运行的过程中很耗时。
3、java 调用 python 脚本函数(如何传递参数)
这次写一个带参数的 python 脚本——简单的两数相加:
def add(x,y):
return x+y
脚本名称为 add.py,java 调用 python 函数代码为:
PythonInterpreter interpreter = new PythonInterpreter();
interpreter.execfile("E:\\RunTime.py");
// 第一个参数为期望获得的函数(变量)的名字,第二个参数为期望返回的对象类型
PyFunction pyFunction = interpreter.get("add", PyFunction.class);
int a = 5, b = 10;
//调用函数,如果函数需要参数,在Java中必须先将参数转化为对应的“Python类型”
PyObject pyobj = pyFunction.__call__(new PyInteger(a), new PyInteger(b));
System.out.println("the anwser is: " + pyobj);
运行结果:
4、java 调用 python 机器学习模型
java 调用 python 机器学习模型,我总结了一下共有四种:
- 利用上面的 java 调用 python 脚本——训练集和测试集写入文本中,python 脚本进行读取
- 将 python 训练的模型参数保存到文本中,用 java 代码重现模型的预测算法。这种工作量很大,而且出现的 bug 几率大大增加。最重要的是很多深度学习的框架就没办法用了。
- 使用 python 进程运行深度学习中训练的模型,在 java 应用程序中调用 python 进程提供的服务。这种方法没尝试过。python 语言写得程序毕竟还是在 python 环境中执行最有效率。而且 python 应用和 java 应用可以运行在不同的服务器上,通过进程的远程访问调用。
- 将机器学习模型保存为 pmml 文件,然后 java 调用 pmml 文件。这种方法是网上最常见的方法,进行上线部署的时候,不会依赖于 python 环境,推荐使用。
上面四种方法,前三种都需要依赖于 python 环境,如果要部署的系统中存在 python 环境那么使用前三种是可以的,如果没有,那么第四种方法是最优的。
下面对于 pmml 进行介绍和实例。
4.1 pmml 介绍
PMML:Predictive Model Markup Language 预测模型标记语言。data mining group 推出的,有十多年的历史了。是一种可以呈现预测分析模型的事实标准语言。标准东西的好处就是,各种开发语言都可以使用相应的包,把模型文件转成这种中间格式,而另外一种开发语言,可以使用相应的包导入该文件做线上预测。
PMML 是数据挖掘的一种通用的规范,它用统一的 XML 格式来描述我们生成的机器学习模型。这样无论你的模型是 sklearn,R 还是 Spark MLlib 生成的,我们都可以将其转化为标准的 XML 格式来存储。当我们需要将这个 PMML 的模型用于部署的时候,可以使用目标环境的解析 PMML 模型的库来加载模型,并做预测。
可以看出,要使用 PMML,需要两步的工作,第一块是将离线训练得到的模型转化为 PMML 模型文件,第二块是将 PMML 模型文件载入在线预测环境,进行预测。这两块都需要相关的库支持。
4.2 实例代码实现
针对 CIC-IDS2017 数据集,sklearn 决策树机器学习模型保存为 pmml 文件实现代码:
import pandas as pd
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn2pmml import sklearn2pmml
# 加载数据
raw_data_filename = "data/clearData/total_expend.csv"
print("Loading raw data...")
raw_data = pd.read_csv(raw_data_filename, header=None,low_memory=False)
# 随机抽取比例
# raw_data=raw_data.sample(frac=0.03)
# 将非数值型的数据转换为数值型数据
raw_data[last_column_index], attacks = pd.factorize(raw_data[last_column_index], sort=True)
# 对原始数据进行切片,分离出特征和标签,第1~41列是特征,第42列是标签
features = raw_data.iloc[:, :raw_data.shape[1] - 1] # pandas中的iloc切片是完全基于位置的索引
labels = raw_data.iloc[:, raw_data.shape[1] - 1:]
# 数据标准化
# features = preprocessing.scale(features)
# features = pd.DataFrame(features)
# 将多维的标签转为一维的数组
labels = labels.values.ravel()
# 将数据分为训练集和测试集,并打印维数
df = pd.DataFrame(features)
X_train, X_test, y_train, y_test = train_test_split(df, labels, train_size=0.8, test_size=0.2, stratify=labels)
pipeline = PMMLPipeline([("classifier", DecisionTreeClassifier(criterion='entropy', max_depth=12, min_samples_leaf=1, splitter="best"))])
pipeline.fit(X_train, y_train)
sklearn2pmml(pipeline, "data/pmml/DecisionTreeIris.pmml", with_repr = True)
存储的 pmml 文件内容:
java 调用生成的 pmml 文件,并进行预测新数据代码:
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.jpmml.evaluator.support_vector_machine.VoteDistribution;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.io.*;
/**
* 分类模型测试
*/
class PMMLDemo {
/**
* 加载模型
*/
private Evaluator loadPmml() {
PMML pmml = new PMML();
InputStream inputStream = null;
try {
// 读取resources文件夹下的pmml文件
Resource resource = new ClassPathResource("DecisionTreeIris.pmml");
inputStream = resource.getInputStream();
} catch (IOException e) {
e.printStackTrace();
}
if (inputStream == null) {
return null;
}
InputStream is = inputStream;
try {
pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
} catch (SAXException | JAXBException e1) {
e1.printStackTrace();
} finally {
//关闭输入流
try {
is.close();
} catch (IOException e) {
e.printStackTrace();
}
}
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
return modelEvaluatorFactory.newModelEvaluator(pmml);
}
/**
* 分类预测
*/
private void predict(Evaluator evaluator,Map<String, Double> featuremap) {
List<InputField> inputFields = evaluator.getInputFields();
System.out.println(inputFields);
// 从原始特征获取数据,作为模型输入
Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
for (InputField inputField : inputFields) {
// 特征名称
FieldName inputFieldName = inputField.getName();
// 特征值
Object rawValue = featuremap.get(inputFieldName.getValue());
FieldValue inputFieldValue = inputField.prepare(rawValue);
arguments.put(inputFieldName, inputFieldValue);
}
// 预测结果
Map<FieldName, ?> results = evaluator.evaluate(arguments);
List<TargetField> targetFields = evaluator.getTargetFields();
for (TargetField targetField : targetFields) {
FieldName targetFieldName = targetField.getName();
Object targetFieldValue = results.get(targetFieldName);
System.err.println("target: " + targetFieldName.getValue()
+ " value: " + targetFieldValue);
}
}
//读取csv文件
private static String readCSV(){
//第一步:先获取csv文件的路径,通过BufferedReader类去读该路径中的文件
File csv = new File("E:\\ideaProject\\javaRunPython\\src\\main\\resources\\1.csv");
String lineDta = "";
try{
//第二步:从字符输入流读取文本,缓冲各个字符,从而实现字符、数组和行(文本的行数通过回车符来进行判定)的高效读取。
BufferedReader textFile = new BufferedReader(new FileReader(csv));
//第三步:将文档的下一行数据赋值给lineData,并判断是否为空,若不为空则输出
while (textFile.readLine()!= null){
lineDta = textFile.readLine();
}
textFile.close();
}catch (FileNotFoundException e){
System.out.println("没有找到指定文件");
}catch (IOException e){
System.out.println("文件读写出错");
}
return lineDta;
}
public static void main(String args[]){
PMMLDemo demo = new PMMLDemo();
Evaluator model = demo.loadPmml();
Map<String, Double> data = new HashMap<String, Double>();
//读取测试数据(一行),并对其进行处理
String test = readCSV();
System.out.println(test);
String[] tests=test.split("\t");
for(int i=0;i<tests.length-1;i++){
data.put(""+i,Double.valueOf(tests[i]));
}
//将测试数据data放入模型中进行预测
demo.predict(model,data);
}
}
先说一下最终结果:
打印的时候,用的是红色 err 打印的(注意这里不是报错),显示的数据 data 对于各个类别预测概率,其中预测类别为 6 的概率最大为 99.4%(标红框)。
4.3 代码实例解释
对于 python 的代码这里不解释了,比较简单。
具体说一下 java 的,主要分为三部分:
(1)java 读取 pmml 文件并将其转换为 java 机器学习模型 model
简单来说,就是利用流读取 pmml 文件,然后 java 导入的 pmml 库将其实现为一个机器学习模块。
(2)读取 csv 测试数据,并进行数据处理
因为 CIC-IDS2017 数据集比较大(共有 80 个特征),所以不能手写,看网上的例子,一般是通过如下进行手动写的:
data.put("x1", 5.1);
data.put("x2", 3.5);
data.put("x3", 1.4);
data.put("x4", 0.2);
这里需要说一下为什么数据集是一个 map 形式,要知道在 python 里面,数据集就是单纯数据列(只有 value,没有 key),这是因为在 python 进行数据训练的时候,sklearn 的决策树模型会自动加一个 key,这个 key 就是数据的索引。
我再看网上的例子的时候,发现有一部分手写测试数据的时候,并没有写全,但是依然能够运行,这个是因为机器学习模型并没有全部应用特征数据,就比如上面的决策树,它的层数只有十几层,也就是用到了只是十几个特征,你只需要将这十几个特征写入数据集就行,前提是这十几个特征的索引你要写对,其实可以打印一下模型中的特征,也就是下面的 inputFields 的索引:
for (InputField inputField : inputFields) {
// 特征名称
FieldName inputFieldName = inputField.getName();
// 特征值
Object rawValue = featuremap.get(inputFieldName.getValue());
FieldValue inputFieldValue = inputField.prepare(rawValue);
arguments.put(inputFieldName, inputFieldValue);
}
在对数据集进行处理的时候,曾遇到一个问题,就是读取 CSV 文件后获得一个字符,如何切分的问题。读取后的字符串是这个样子:
我一开始用单空格切分,发现不行,双空格也不行,打印的时候,发现字符串中有很多\t,于是就用\t 分割,果然正确,网上查了一下\t 表示
\t 是补全当前字符串长度到 8 的整数倍,最少 1 个最多 8 个空格,补多少要看你\t 前字符串长度。
看了这个解释,顿时明白了,csv 文件本身就是表格性质的,\t 相当于补全了。
(3)读取测试数据 data,进行预测
这一步理解起来就比较简单了,因为模型是决策树,从根节点出发,读取特征(inputFields),根据所需特征从 data 中读取出来,然后进行预测。
4.4 遇到的问题
遇到的问题主要是 java 这边的
(1)java-source1.5 中不支持 multi-catch 语句
解决方案参考:https://blog.csdn.net/qq_39793857/article/details/106925721
(2)Exception in thread “main“ java.lang.IllegalArgumentException: http://www.dmg.org/PMML-4_4 is not support
这是版本问题,要修改 pmml 文件的表头,具体参考:
https://blog.csdn.net/qq_32113189/article/details/107542225
OK,今天就到这里,更多精彩内容关注我的个人网站:蓝亚之舟博客。