Calling a Python program interactively from Java

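In my previous article, "Running Python Programs from Java", I mentioned that I wanted to call Python from Java for machine learning. Java has to be the caller because the open-source library used to evaluate the model is implemented in Java.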

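There is still a problem, though: loading the model takes a long time on every run. The model is loaded only once if all the data is processed in a single batch, but during a simulation the data does not necessarily arrive all at once. The Python script then has to be called repeatedly, every call reloads the model, and that overhead may end up affecting the evaluation results.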

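        // Batch: one call with three inputs, so the model is loaded once
        String[] batch = {"1,2,3", "1,2,3", "1,2,3"};
        System.out.println(pythonRun.run(pyPath, batch));

        // One input per call, so the model is loaded three times
        for (int i = 0; i < 3; i++) {
            String[] single = {"1,2,3"};
            System.out.println(pythonRun.run(pyPath, single));
        }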

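The two snippets print the same results but take very different amounts of time: the first is a batch call and pays the model-loading time once, while the second pays it three times.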
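To make the cost difference concrete: with N inputs, a batch call pays the model-load cost once, while one call per input pays it N times. The toy sketch below counts loads instead of timing them; `LoadCostDemo` and its `run` method are hypothetical stand-ins for launching the Python script, not part of `PythonRun`.

```java
import java.util.List;

// Hypothetical stand-in: each call to run() models one new Python process,
// which must load the model before it can handle its inputs.
public class LoadCostDemo {
    static int modelLoads = 0;
    static int inputsProcessed = 0;

    static void run(List<String> inputs) {
        modelLoads++;                  // every new process pays the load cost once
        for (String input : inputs) {
            inputsProcessed++;         // stands in for predicting on `input`
        }
    }

    // One call carrying all three inputs (like the batch snippet)
    static int batchLoads() {
        modelLoads = 0;
        run(List.of("1,2,3", "1,2,3", "1,2,3"));
        return modelLoads;
    }

    // Three calls with one input each (like the loop snippet)
    static int perInputLoads() {
        modelLoads = 0;
        for (int i = 0; i < 3; i++) {
            run(List.of("1,2,3"));
        }
        return modelLoads;
    }

    public static void main(String[] args) {
        System.out.println("batch: " + batchLoads() + " load(s)");        // batch: 1 load(s)
        System.out.println("per input: " + perInputLoads() + " load(s)"); // per input: 3 load(s)
    }
}
```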

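If the Python program is redesigned so that it can read input and write output repeatedly within a single run, the model only has to be loaded once. That requires Java to interact with the running Python process.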
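The request/response idea can be tried without Python at all. In this sketch (all names hypothetical), a background thread plays the role of the redesigned script: it "loads the model" once and then answers one line per request until it reads `exit`. Java's piped streams stand in for the process's stdin/stdout, and the main thread sends a request and blocks on `readLine()` for its answer, which is the same pattern `input()`/`getResult()` follow with a real process.

```java
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

public class InteractiveLoopDemo {
    static volatile int modelLoads = 0;

    // Hypothetical stand-in for the redesigned Python script: load the model once,
    // then answer one line per request until "exit" or end of input.
    static void serve(BufferedReader requests, PrintWriter responses) throws IOException {
        modelLoads++;                                // the expensive load happens once
        String line;
        while ((line = requests.readLine()) != null && !line.equals("exit")) {
            responses.println("result for " + line); // stands in for model.predict(line)
            responses.flush();                       // push the answer out immediately
        }
    }

    // Java side: send each request, then block until its single answer line arrives.
    static List<String> talk(List<String> inputs) {
        modelLoads = 0;
        try {
            PipedOutputStream toServer = new PipedOutputStream();
            PipedInputStream serverIn = new PipedInputStream(toServer);
            PipedOutputStream fromServer = new PipedOutputStream();
            PipedInputStream clientIn = new PipedInputStream(fromServer);

            Thread server = new Thread(() -> {
                try (BufferedReader r = new BufferedReader(new InputStreamReader(serverIn, StandardCharsets.UTF_8));
                     PrintWriter w = new PrintWriter(new OutputStreamWriter(fromServer, StandardCharsets.UTF_8))) {
                    serve(r, w);
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            });
            server.start();

            PrintWriter out = new PrintWriter(new OutputStreamWriter(toServer, StandardCharsets.UTF_8));
            BufferedReader in = new BufferedReader(new InputStreamReader(clientIn, StandardCharsets.UTF_8));
            List<String> answers = new ArrayList<>();
            for (String input : inputs) {
                out.println(input);
                out.flush();
                answers.add(in.readLine());          // one request, one answer line
            }
            out.println("exit");                     // lets the server loop end
            out.flush();
            server.join();
            return answers;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(talk(List.of("a", "b", "c")));
        System.out.println("model loads: " + modelLoads);
    }
}
```

However many requests are sent, the counter stays at one load, which is the whole point of keeping the process alive.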

The core code:
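import java.io.*;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

public class PythonRun {
    private String environment = "python";   // path of the Python interpreter
    private String root = null;              // appended to the script's sys.path, if set
    private String cache = "cache/";         // directory for the preprocessed copy of the script
    private boolean autoRemoveCache = true;  // delete that copy when the process is closed

    // Handle to a running Python process that can be written to and read from repeatedly
    public static class AResult {
        private PythonRun pythonRun;
        private Process process;
        private String path;
        private BufferedWriter out;   // -> Python's stdin
        private BufferedReader in;    // <- Python's stdout

        public AResult(PythonRun pythonRun, Process process, String path) {
            this.pythonRun = pythonRun;
            this.process = process;
            this.path = path;

            out = new BufferedWriter(new OutputStreamWriter(process.getOutputStream()));
            in = new BufferedReader(new InputStreamReader(process.getInputStream()));
        }

        // Deletes the cached script (if enabled) and kills the Python process
        public void close() {
            if (pythonRun.autoRemoveCache && path != null)
                new File(path).delete();
            process.destroy();
        }

        // Sends one line to Python's stdin; flush() so it is delivered immediately
        public void input(String message) throws IOException {
            out.write(message + "\n");
            out.flush();
        }

        // Blocks until at least one line arrives, then keeps reading while more data
        // is already buffered. After the Python process exits, readLine() returns null,
        // which is appended as the text "null".
        public String getResult() throws IOException {
            String line;
            StringBuilder result = new StringBuilder();
            do {
                line = in.readLine();
                result.append(line).append("\n");
            } while (in.ready());
            return result.toString();
        }
    }

    // Starts the script once and returns a handle instead of waiting for it to finish
    public AResult asyncRun(String path, String... args) throws IOException {
        path = createNewPy(path);  // createNewPy(...) is in the full listing below
        List<String> inputArgs = new LinkedList<>(Arrays.asList(environment, path));  // build the command line
        inputArgs.addAll(Arrays.asList(args));
        Process process = Runtime.getRuntime().exec(inputArgs.toArray(new String[0]));
        return new AResult(this, process, path);
    }
}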
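`asyncRun` starts the interpreter with `Runtime.exec` and only ever reads its stdout. If the script writes a lot to stderr (warnings, tracebacks), that pipe can fill up and block the Python process. One possible alternative, sketched here with `java.lang.ProcessBuilder` (the class and method names `ProcessSetup`/`builderFor` are hypothetical), sends stderr straight to the Java console and passes Python's `-u` flag so stdout is unbuffered even without `flush=True` in the script.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ProcessSetup {
    // Builds (but does not start) the child process for an interactive script.
    static ProcessBuilder builderFor(String python, String script, String... args) {
        List<String> cmd = new ArrayList<>();
        cmd.add(python);
        cmd.add("-u");                       // Python flag: unbuffered stdout/stderr
        cmd.add(script);
        cmd.addAll(Arrays.asList(args));
        ProcessBuilder pb = new ProcessBuilder(cmd);
        // Send the script's stderr (tracebacks, warnings) to this JVM's stderr, so it
        // never fills an unread pipe and never mixes with the answers on stdout.
        pb.redirectError(ProcessBuilder.Redirect.INHERIT);
        return pb;
    }

    public static void main(String[] args) {
        ProcessBuilder pb = builderFor("python", "cache/testForComm.py");
        System.out.println(pb.command());
        // Process p = pb.start();   // start only where a Python interpreter is available
    }
}
```

Keeping stderr separate (rather than `redirectErrorStream(true)`) matters here, because merged warnings would be read by `getResult()` as if they were answers.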
Usage:

Java test code:

public class PythonRunDemo {  // wrapper class so the snippet compiles as a program
    public static void main(String[] args) throws Exception {
        String pyPath = "E:\\pythonProject\\MEC-Study\\src\\test\\testForComm.py"; // Python script path
        String pyEnvironment = "E:\\Anaconda3\\envs\\MEC-Study\\python.exe";     // interpreter of the conda env
        PythonRun pythonRun = new PythonRun();
        pythonRun.setEnvironment(pyEnvironment);
        pythonRun.setRoot("E:\\pythonProject\\MEC-Study\\src");                 // appended to sys.path
        PythonRun.AResult aResult = pythonRun.asyncRun(pyPath);                   // starts Python once
        aResult.input("a");
        System.out.println(aResult.getResult());
        aResult.input("b");
        System.out.println(aResult.getResult());
        aResult.input("exit");                                                    // ends the Python loop
        System.out.println(aResult.getResult());
        aResult.close();
    }
}

Python test code:

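if __name__ == '__main__':
    # a real script would load the model here, once, before the loop
    message = input()
    while message != "exit":
        print(message, flush=True)  # flush so Java's readLine() receives the line right away
        message = input()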

Output:

a

b

null


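Process finished with exit code 0

Every answer is followed by an empty line because getResult() ends each line with "\n" and println adds another. The final null comes from the "exit" step: the Python loop ends and the process exits, so in.readLine() reaches the end of the stream and returns null, which getResult() appends as the text "null".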
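`getResult()` treats a response as complete as soon as `in.ready()` reports no more buffered data. For one-line answers that works, but a multi-line answer whose lines arrive in separate writes can be cut short, with the rest attributed to the next call. A sturdier option, sketched below under an assumed convention, is to have the Python script print an end marker line after each answer and read until that marker; end of stream is reported as a real `null`. The marker `<<END>>` and the names `ResponseReader`, `readResponse` and `readAll` are hypothetical.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

public class ResponseReader {
    // Hypothetical convention: after every answer the Python script prints this
    // line, e.g. print("<<END>>", flush=True). It is not part of the original script.
    static final String END_MARKER = "<<END>>";

    // Reads one complete response: every line up to the marker.
    // Returns null (not the text "null") if stdout closed before any line arrived.
    static String readResponse(BufferedReader in) throws IOException {
        StringBuilder result = new StringBuilder();
        boolean gotAny = false;
        String line;
        while ((line = in.readLine()) != null) {
            if (line.equals(END_MARKER)) {
                return result.toString();
            }
            result.append(line).append('\n');
            gotAny = true;
        }
        // End of stream: the Python process has exited
        return gotAny ? result.toString() : null;
    }

    // Splits a captured stdout stream into responses, stopping at end of stream
    static List<String> readAll(String stdout) {
        List<String> responses = new ArrayList<>();
        try (BufferedReader in = new BufferedReader(new StringReader(stdout))) {
            String r;
            while ((r = readResponse(in)) != null) {
                responses.add(r);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return responses;
    }

    public static void main(String[] args) {
        // A two-line answer followed by a one-line answer, then the process exits
        System.out.println(readAll("line1\nline2\n<<END>>\nb\n<<END>>\n"));
    }
}
```

With a real process, `readResponse` would be called on the `BufferedReader` wrapping `process.getInputStream()` in place of the `do … while (in.ready())` loop.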

Full code:
import java.io.*;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

public class PythonRun {
    private String environment = "python";
    private String root = null;
    private String cache = "cache/";
    private boolean autoRemoveCache = true;

    public String run(String path, String ...args) throws IOException {
        path = createNewPy(path);
        List<String> inputArgs = new LinkedList<>(Arrays.asList(environment, path));  // build the command line
        inputArgs.addAll(Arrays.asList(args));
        StringBuilder result = new StringBuilder();
        try {
            Process proc = Runtime.getRuntime().exec(inputArgs.toArray(new String[0]));  // run the .py file
            BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream()));
            String line;
            while ((line = in.readLine()) != null) {
                result.append(line).append("\n");
            }
            in.close();
            proc.waitFor();
        } catch (Exception e) {
            e.printStackTrace();
        }
        if (autoRemoveCache && path != null)
            new File(path).delete();
        return result.toString();
    }

    public static class AResult{
        private PythonRun pythonRun;
        private Process process;
        private String path;
        private BufferedWriter out;
        private BufferedReader in;

        public AResult(PythonRun pythonRun, Process process, String path) {
            this.pythonRun = pythonRun;
            this.process = process;
            this.path = path;

            out = new BufferedWriter(new OutputStreamWriter(process.getOutputStream()));
            in = new BufferedReader(new InputStreamReader(process.getInputStream()));
        }

        public void close() {
            if (pythonRun.autoRemoveCache && path != null)
                new File(path).delete();
            process.destroy();
        }

        public void input(String message) throws IOException {
            out.write(message+"\n");
            out.flush();
        }

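        // Blocks until at least one line arrives, then keeps reading while more data
        // is already buffered. After the Python process exits, readLine() returns null,
        // which is appended as the text "null".
        public String getResult() throws IOException {
            String line;
            StringBuilder result = new StringBuilder();
            do {
                line = in.readLine();
                result.append(line).append("\n");
            } while (in.ready());
            return result.toString();
        }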
    }

    public AResult asyncRun(String path, String ...args) throws IOException {
        path = createNewPy(path);
        List<String> inputArgs = new LinkedList<>(Arrays.asList(environment, path));  // build the command line
        inputArgs.addAll(Arrays.asList(args));
        Process process = Runtime.getRuntime().exec(inputArgs.toArray(new String[0]));
        return new AResult(this, process, path);
    }

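    // Copies the script into the cache directory, prepending
    // "import sys; sys.path.append(root)" when a root is set, and returns the copy's path
    private String createNewPy(String path) throws IOException {
        File file = new File(path);
        if (!file.isFile())
            throw new FileNotFoundException(path);  // fail early instead of passing null to exec()
        String result = loadTxt(file);
        if (root != null) {
            // forward slashes: inside a Python string literal, a Windows path like E:\test would contain "\t"
            result = "import sys\n" +
                    "sys.path.append(\"" + root.replace("\\", "/") + "\")\n" + result;
        }
        String save = cache + file.getName();
        saveTxt(save, result);
        return save;
    }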

    private static File saveTxt(String filename, String string) {
        File file = new File(filename);
        File dir = file.getParentFile();
        if (dir != null)
            dir.mkdirs();  // the cache directory may not exist yet
        try (BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), "UTF-8"))) {
            out.write(string);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return file;
    }

    private String loadTxt(File file) {
        StringBuilder result = new StringBuilder();
        // try-with-resources closes the file, which the original version never did
        try (BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(file), "UTF-8"))) {
            String str;
            while ((str = in.readLine()) != null) {
                result.append(str).append("\n");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return result.toString();
    }

    public String getCache() {
        return cache;
    }

    public void setCache(String cache) {
        this.cache = cache;
    }

    public String getEnvironment() {
        return environment;
    }

    public void setEnvironment(String environment) {
        this.environment = environment;
    }

    public String getRoot() {
        return root;
    }

    public void setRoot(String root) {
        this.root = root;
    }

    public boolean isAutoRemoveCache() {
        return autoRemoveCache;
    }

    public void setAutoRemoveCache(boolean autoRemoveCache) {
        this.autoRemoveCache = autoRemoveCache;
    }
}
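The preprocessing step in `createNewPy` (prepend a `sys.path` line, write a copy into the cache directory) can also be written more compactly with `java.nio.file.Files`, assuming Java 11 or later. This is a hypothetical standalone variant, not a drop-in replacement for the private method above; it also converts backslashes to forward slashes so a Windows root is not misread by Python as containing escape sequences.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

public class PyPreprocess {
    // Copy `script` into `cacheDir`, prepending a sys.path entry for `root` (if any).
    static Path createNewPy(Path script, Path cacheDir, String root) throws IOException {
        String body = Files.readString(script, StandardCharsets.UTF_8);
        if (root != null) {
            // Forward slashes keep a path like E:\test from being read as "\t" by Python
            String safeRoot = root.replace('\\', '/');
            body = "import sys\nsys.path.append(\"" + safeRoot + "\")\n" + body;
        }
        Files.createDirectories(cacheDir);               // no error if it already exists
        Path copy = cacheDir.resolve(script.getFileName());
        Files.writeString(copy, body, StandardCharsets.UTF_8);
        return copy;
    }

    // Preprocesses a throwaway script in a temp directory and returns the copy's text
    static String demo(String root) {
        try {
            Path dir = Files.createTempDirectory("pyrun");
            Path script = Files.writeString(dir.resolve("demo.py"), "print('hi')\n", StandardCharsets.UTF_8);
            return Files.readString(createNewPy(script, dir.resolve("cache"), root), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.print(demo("E:\\test\\src"));
    }
}
```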
