java调用python

前言

这一章来学习如何使用 java 调用 python 机器学习模块,毕竟 python 在算法方法好用,但是做 web 项目还是 java 更优,最近有个项目想要集成机器学习算法,这里简单记录一下(默认使用idea开发工具,默认都会创建maven项目,不会自行百度)。

java 调用 python,分三步来学习:

  • 第一步:java 调用 python 语句
  • 第二步:java 调用 python 脚本
  • 第三步:java调用python脚本函数(如何传递参数)
  • 第四步:java调用python机器学习模块并运行

上面三步都需要调用 jython 库,两种加入项目方法:

  • 官网下载 jar 包,手动加入
  • 在 pom 文件中直接加入依赖
    <dependency>
        <groupId>org.python</groupId>
        <artifactId>jython-standalone</artifactId>
        <version>2.7.0</version>
    </dependency>

1、java 调用 python 语句

首先要在idea中导入jython库,上面提到的两种方法,我使用的是第二种,即在 pom 文件中加入依赖,简单明了。

创建一个 javaRunPython 类,执行下列代码:

PythonInterpreter interpreter = new PythonInterpreter();
interpreter.exec("a=1+2; ");
interpreter.exec("print(a);");

运行结果如下:

在这里插入图片描述

另外,我发现一个有趣的现象,无论是"print a",还是"print(a)",居然都没有报错,也就是说 jython 兼容支持 python2 和 python3 两种语法。

2、java 调用 python 脚本

2.1 PythonInterpreter 调用 python 脚本

首先要写一个 python 脚本用来被调用,内容随意,下面是我写的:

hello = 'hello world, this is using java to pring python word'
print(hello)

然后在 maven 项目中,运行:

import org.python.util.PythonInterpreter;

PythonInterpreter interpreter = new PythonInterpreter();
interpreter.execfile("E:\\pythonTest.py");

运行结果如下:

在这里插入图片描述

上面调用的只是普通的 python 脚本,如果脚本中导入了第三方库,还能不能运行呢?测试一下,写一个简单的生成矩阵 python 脚本:

print("sdafd")
import numpy as np
n = np.arange(0, 30, 2)
n = n.reshape(3, 5)
print(n)
print("dafdafda")

运行结果:

在这里插入图片描述

由上图可以看到,只执行了第一句,当导入第三方库的时候报错了,这个是因为在 jython 库中不存在 numpy 模块,自然会报错。

由此可以得出:PythonInterpreter 可以简单执行普通的 python 脚本,但是对于带有第三方库的 python 脚本就不行了。

2.2 使用 Runtime 调用 python 脚本(推荐)

先来个普通脚本:

hello = 'hello world, this is using java to pring python word'
print(hello)

文件名为 Runtime.py

Runtime 方法的 java 代码如下:

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;

Process proc;
try {
    proc = Runtime.getRuntime().exec("python E:\\Runtime.py");
    BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream()));
    String line = null;
    while ((line = in.readLine()) != null) {
        System.out.println(line);
    }
    in.close();
    proc.waitFor();
} catch (IOException e) {
    e.printStackTrace();
} catch (InterruptedException e) {
    e.printStackTrace();
}

执行结果和上面一样

再来看看带有第三方库的脚本:

print("sdafd")
import numpy as np
n = np.arange(0, 30, 2)
n = n.reshape(3, 5)
print(n)
print("dafdafda")

因为文件名和位置没有变化,Runtime 的代码不变,再次执行,结果如下:

在这里插入图片描述

可以看到,完美执行,没有报错。

这里说一下为什么没有报错:

Runtime 执行本质类似于控制台调用 python.exe 去执行 python 脚本,也就是调用的 python 线程,python 仓库中本身是有 numpy 模块的,自然就不会报错了;

proc = Runtime.getRuntime().exec("python E:\\Runtime.py");

上方的代码中,可以看到 python 字样,这个是调用 python 程序的,只是 python 配置了环境变量,否则就要换成 D:\Python3.8.6\python.exe 了。

该方法有一个缺点,它得到 python 执行结果是通过数据流得到的,每次读取一行,相当于每执行一次就要读取一次结果,这就会导致我们运行的过程中很耗时。

3、java 调用 python 脚本函数(如何传递参数)

这次写一个带参数的 python 脚本——简单的两数相加:

def add(x,y):
   return x+y

脚本名称为 add.py,java 调用 python 函数代码为:

PythonInterpreter interpreter = new PythonInterpreter();
interpreter.execfile("E:\\RunTime.py");

// 第一个参数为期望获得的函数(变量)的名字,第二个参数为期望返回的对象类型
PyFunction pyFunction = interpreter.get("add", PyFunction.class);
int a = 5, b = 10;
//调用函数,如果函数需要参数,在Java中必须先将参数转化为对应的“Python类型”
PyObject pyobj = pyFunction.__call__(new PyInteger(a), new PyInteger(b));
System.out.println("the anwser is: " + pyobj);

运行结果:

在这里插入图片描述

4、java 调用 python 机器学习模型

java 调用 python 机器学习模型,我总结了一下共有四种:

  1. 利用上面的 java 调用 python 脚本——训练集和测试集写入文本中,python 脚本进行读取
  2. 将 python 训练的模型参数保存到文本中,用 java 代码重现模型的预测算法。这种工作量很大,而且出现的 bug 几率大大增加。最重要的是很多深度学习的框架就没办法用了。
  3. 使用 python 进程运行深度学习中训练的模型,在 java 应用程序中调用 python 进程提供的服务。这种方法没尝试过。python 语言写得程序毕竟还是在 python 环境中执行最有效率。而且 python 应用和 java 应用可以运行在不同的服务器上,通过进程的远程访问调用。
  4. 将机器学习模型保存为 pmml 文件,然后 java 调用 pmml 文件。这种方法是网上最常见的方法,进行上线部署的时候,不会依赖于 python 环境,推荐使用。

上面四种方法,前三种都需要依赖于 python 环境,如果要部署的系统中存在 python 环境那么使用前三种是可以的,如果没有,那么第四种方法是最优的。

下面对于 pmml 进行介绍和实例。

4.1 pmml 介绍

PMML:Predictive Model Markup Language 预测模型标记语言。data mining group 推出的,有十多年的历史了。是一种可以呈现预测分析模型的事实标准语言。标准东西的好处就是,各种开发语言都可以使用相应的包,把模型文件转成这种中间格式,而另外一种开发语言,可以使用相应的包导入该文件做线上预测。

PMML 是数据挖掘的一种通用的规范,它用统一的 XML 格式来描述我们生成的机器学习模型。这样无论你的模型是 sklearn,R 还是 Spark MLlib 生成的,我们都可以将其转化为标准的 XML 格式来存储。当我们需要将这个 PMML 的模型用于部署的时候,可以使用目标环境的解析 PMML 模型的库来加载模型,并做预测。

可以看出,要使用 PMML,需要两步的工作,第一块是将离线训练得到的模型转化为 PMML 模型文件,第二块是将 PMML 模型文件载入在线预测环境,进行预测。这两块都需要相关的库支持。

4.2 实例代码实现

针对 CIC-IDS2017 数据集,sklearn 决策树机器学习模型保存为 pmml 文件实现代码:

import pandas as pd
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn2pmml import sklearn2pmml

# 加载数据
raw_data_filename = "data/clearData/total_expend.csv"
print("Loading raw data...")
raw_data = pd.read_csv(raw_data_filename, header=None,low_memory=False)

# 随机抽取比例
# raw_data=raw_data.sample(frac=0.03)

# 将非数值型的数据转换为数值型数据
raw_data[last_column_index], attacks = pd.factorize(raw_data[last_column_index], sort=True)
# 对原始数据进行切片,分离出特征和标签,第1~41列是特征,第42列是标签
features = raw_data.iloc[:, :raw_data.shape[1] - 1]  # pandas中的iloc切片是完全基于位置的索引
labels = raw_data.iloc[:, raw_data.shape[1] - 1:]
# 数据标准化
# features = preprocessing.scale(features)
# features = pd.DataFrame(features)
# 将多维的标签转为一维的数组
labels = labels.values.ravel()

# 将数据分为训练集和测试集,并打印维数
df = pd.DataFrame(features)
X_train, X_test, y_train, y_test = train_test_split(df, labels, train_size=0.8, test_size=0.2, stratify=labels)

pipeline = PMMLPipeline([("classifier", DecisionTreeClassifier(criterion='entropy', max_depth=12, min_samples_leaf=1, splitter="best"))])
pipeline.fit(X_train, y_train)
sklearn2pmml(pipeline, "data/pmml/DecisionTreeIris.pmml", with_repr = True)

存储的 pmml 文件内容:

在这里插入图片描述

java 调用生成的 pmml 文件,并进行预测新数据代码:

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.jpmml.evaluator.support_vector_machine.VoteDistribution;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.io.*;

/**
 *  分类模型测试
 */
class PMMLDemo {

    /**
     *  加载模型
     */
    private Evaluator loadPmml() {
        PMML pmml = new PMML();
        InputStream inputStream = null;
        try {
            // 读取resources文件夹下的pmml文件
            Resource resource = new ClassPathResource("DecisionTreeIris.pmml");
            inputStream = resource.getInputStream();
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (inputStream == null) {
            return null;
        }
        InputStream is = inputStream;
        try {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        } catch (SAXException | JAXBException e1) {
            e1.printStackTrace();
        } finally {
            //关闭输入流
            try {
                is.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml);
    }

    /**
     *  分类预测
     */
    private void predict(Evaluator evaluator,Map<String, Double> featuremap) {

        List<InputField> inputFields = evaluator.getInputFields();
        System.out.println(inputFields);
        // 从原始特征获取数据,作为模型输入
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFields) {
            // 特征名称
            FieldName inputFieldName = inputField.getName();
            // 特征值
            Object rawValue = featuremap.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }
        // 预测结果
        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        List<TargetField> targetFields = evaluator.getTargetFields();
        for (TargetField targetField : targetFields) {
            FieldName targetFieldName = targetField.getName();
            Object targetFieldValue = results.get(targetFieldName);
            System.err.println("target: " + targetFieldName.getValue()
                    + " value: " + targetFieldValue);
        }

    }

    //读取csv文件
    private static String readCSV(){
        //第一步:先获取csv文件的路径,通过BufferedReader类去读该路径中的文件
        File csv = new File("E:\\ideaProject\\javaRunPython\\src\\main\\resources\\1.csv");
        String lineDta = "";
        try{
            //第二步:从字符输入流读取文本,缓冲各个字符,从而实现字符、数组和行(文本的行数通过回车符来进行判定)的高效读取。
            BufferedReader textFile = new BufferedReader(new FileReader(csv));
            //第三步:将文档的下一行数据赋值给lineData,并判断是否为空,若不为空则输出
            while (textFile.readLine()!= null){
                lineDta = textFile.readLine();
            }
            textFile.close();
        }catch (FileNotFoundException e){
            System.out.println("没有找到指定文件");
        }catch (IOException e){
            System.out.println("文件读写出错");
        }
        return lineDta;
    }

    public static void main(String args[]){
        PMMLDemo demo = new PMMLDemo();
        Evaluator model = demo.loadPmml();
        Map<String, Double> data = new HashMap<String, Double>();
        //读取测试数据(一行),并对其进行处理
        String test = readCSV();
        System.out.println(test);
        String[] tests=test.split("\t");
        for(int i=0;i<tests.length-1;i++){
            data.put(""+i,Double.valueOf(tests[i]));
        }
        //将测试数据data放入模型中进行预测
        demo.predict(model,data);
    }
}

先说一下最终结果:

在这里插入图片描述

打印的时候,用的是红色 err 打印的(注意这里不是报错),显示的数据 data 对于各个类别预测概率,其中预测类别为 6 的概率最大为 99.4%(标红框)。

4.3 代码实例解释

对于 python 的代码这里不解释了,比较简单。

具体说一下 java 的,主要分为三部分:

(1)java 读取 pmml 文件并将其转换为 java 机器学习模型 model

简单来说,就是利用流读取 pmml 文件,然后 java 导入的 pmml 库将其实现为一个机器学习模块。

(2)读取 csv 测试数据,并进行数据处理

因为 CIC-IDS2017 数据集比较大(共有 80 个特征),所以不能手写,看网上的例子,一般是通过如下进行手动写的:

data.put("x1", 5.1);
data.put("x2", 3.5);
data.put("x3", 1.4);
data.put("x4", 0.2);

这里需要说一下为什么数据集是一个 map 形式,要知道在 python 里面,数据集就是单纯数据列(只有 value,没有 key),这是因为在 python 进行数据训练的时候,sklearn 的决策树模型会自动加一个 key,这个 key 就是数据的索引。

我再看网上的例子的时候,发现有一部分手写测试数据的时候,并没有写全,但是依然能够运行,这个是因为机器学习模型并没有全部应用特征数据,就比如上面的决策树,它的层数只有十几层,也就是用到了只是十几个特征,你只需要将这十几个特征写入数据集就行,前提是这十几个特征的索引你要写对,其实可以打印一下模型中的特征,也就是下面的 inputFields 的索引:

 for (InputField inputField : inputFields) {
    // 特征名称
    FieldName inputFieldName = inputField.getName();
    // 特征值
    Object rawValue = featuremap.get(inputFieldName.getValue());
    FieldValue inputFieldValue = inputField.prepare(rawValue);
    arguments.put(inputFieldName, inputFieldValue);
}

在对数据集进行处理的时候,曾遇到一个问题,就是读取 CSV 文件后获得一个字符,如何切分的问题。读取后的字符串是这个样子:

在这里插入图片描述

我一开始用单空格切分,发现不行,双空格也不行,打印的时候,发现字符串中有很多\t,于是就用\t 分割,果然正确,网上查了一下\t 表示

\t 是补全当前字符串长度到 8 的整数倍,最少 1 个最多 8 个空格,补多少要看你\t 前字符串长度。

看了这个解释,顿时明白了,csv 文件本身就是表格性质的,\t 相当于补全了。

(3)读取测试数据 data,进行预测

这一步理解起来就比较简单了,因为模型是决策树,从根节点出发,读取特征(inputFields),根据所需特征从 data 中读取出来,然后进行预测。

4.4 遇到的问题

遇到的问题主要是 java 这边的

(1)java-source1.5 中不支持 multi-catch 语句

在这里插入图片描述

解决方案参考:https://blog.csdn.net/qq_39793857/article/details/106925721

(2)Exception in thread “main“ java.lang.IllegalArgumentException: http://www.dmg.org/PMML-4_4 is not support

在这里插入图片描述

这是版本问题,要修改 pmml 文件的表头,具体参考:
https://blog.csdn.net/qq_32113189/article/details/107542225

OK,今天就到这里,更多精彩内容关注我的个人网站:蓝亚之舟博客

  • 17
    点赞
  • 134
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值