如何将kears保存为pb格式并在java中调用

最近在做的工作要在java中使用keras训练好的模型,但是刚接触这方面的知识,于是在网上找了很相关的博客和资料去看,然后记录一下在这个过程中遇到的一些问题以及解决办法.

https://blog.csdn.net/Butertfly/article/details/80952987
关于在keras中如何对数据进行预处理/创建模型/训练模型/保存模型/在java和python中调用,可以参考这篇文章.

1.关于模型保存中输入输出节点名称的问题
我主要是参考上面引用的博客中的代码,但是我发现运行的时候一直报错"",即使复制的一模一样的代码这部分还是出错(自己将上述博客中的代码运行了一遍)

# kera 模型保存为pb文件
sess = K.get_session()
frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["dense_3/Softmax"])
#第一种保存方法
with open('model.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())
# 第二中保存方法
tf.train.write_graph(frozen_graph_def, 'model', 'test_model.pb', as_text=False)
# 查看模型中输入输出节点的名称和大小
# print(model.input.name, model.input.shape)
# print(model.output.name, model.output.shape)

此部分需要注意的是output_node_names为模型输出节点的名称,如果写错云运行的时候会出现"*** is not in graph"的错误,因此可以使用# print(model.input.name, model.input.shape)和 print(model.output.name, model.output.shape)来查看,这两个节点名称在后续的python和java中调用会用到.

2.python中调用pb文件

import tensorflow as tf
import numpy as np
with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    # 读取保存的pb文件,将其解析成对应的GraphDef Protocol Buffer
    with open('/home/thm/PycharmProjects/code/Extraction/Attacks_detection/step/model.pb', "rb") as f:
        output_graph_def.ParseFromString(f.read())
        # graph_def:将graph中保存的图片加载到当前图中,包含要导入到默认图中的操作的GraphDef protocol
        # name:放在graph_def名称前面的前缀,但是并不适用于作为输入的名称,默认的是import
        _ = tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)  # 初始化所有的变量
        a = np.array([6, 1082, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 2, 2, 2])
        # 获取输入输出节点的名称
        input_x = sess.graph.get_tensor_by_name("dense_1_input:0")
        output = sess.graph.get_tensor_by_name("dense_3/Softmax:0")
        # 传入想要计算的参数,将input_x的值替换(feed使用tensorflow值临时替换一个操作的输入参数,从而替换原来的输出结果)
        result = sess.run(output, feed_dict={input_x: a.reshape(1, 15)})
        print(result)
        Class_dict = {'BENIGN': 0, 'syn_flood': 1, 'icmp_flood': 2, 'udp_flood': 3, 'sarfu': 4}
        species_dict = {v: k for k, v in Class_dict.items()}
        # 代码中v:k代表了v是key,k是value,而k,v则是表示key,value换个位置
        print("\nPredicted species is: ")
        print(species_dict[np.argmax(result)])

3.java中调用pb文件
首先要添加依赖关系
方法一: 右击项目/Maven/Add Dependency,在出现的对话框中添加依赖关系(代码中的groupid/artifactid/version和对话框中的三部分一一对应).
方法二:打开项目中的pox.xml文件,将下面的代码直接粘贴过去.
特别注意:导入的tensorflow包的version对应的是你自己tensroflow的版本
(在终端下查看tensorflow版本号
python
import tensorflow as tf
(如果输入该命令之后提示ModuleNotFoundError: No module named ‘tensorflow’,应该是没有在tensorflow环境下运行python,因此先激活tensorflow
source activate tensorflow
再运行python

tf.version)

		<dependency>
			 <groupId>commons-io</groupId>
			 <artifactId>commons-io</artifactId>
			 <version>2.6</version>
		</dependency>
		<dependency>
			<groupId>org.tensorflow</groupId>
			<artifactId>libtensorflow</artifactId>
			<version>1.10.0</version>
		</dependency>
		<dependency>
	         <groupId>org.tensorflow</groupId>
	         <artifactId>proto</artifactId>
	         <version>1.10.0</version>
       </dependency>
       <dependency>
			 <groupId>org.tensorflow</groupId>
			 <artifactId>libtensorflow_jni</artifactId>
			 <version>1.10.0</version>
       </dependency>

这里插入图片描述

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.FloatBuffer;
import org.apache.commons.io.IOUtils;
import org.tensorflow.*;

public class Test_model {
    public static String PB_FILE_PATH = "pb模型保存的位置";
    public static String INPUT_TENSOR_NAME = "dense_1_input:0";//前面提到的输入输出节点的名称
    public static String OUTPUT_TENSOR_NAME = "dense_3/Softmax:0";
 
    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            //导入图
            byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(PB_FILE_PATH));
            graph.importGraphDef(graphBytes);
            float[] a = new float[]{6, 1082, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 2, 2, 2};
            long[] shape = new long[]{1,15};
            Tensor<?> data = Tensor.create(shape, FloatBuffer.wrap(a));
            //根据图建立Session
            try (Session session = new Session(graph)) {
                //相当于TensorFlow Python中的sess.run(z, feed_dict = {'x': 10.0})
                Tensor<?> out = session.runner()
                        .feed(INPUT_TENSOR_NAME, data)
                        .fetch(OUTPUT_TENSOR_NAME).run().get(0);
                System.out.println(out);
            }
        }
    }
}

4.java中在原有项目添加完依赖关系后可能会出现"Exception in thread “main” java.lang.NoClassDefFoundError:"
此问题的具体解决方法可以参考这篇博客,讲的挺详细的https://blog.csdn.net/lz6363/article/details/82561292

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值