java 调用训练好的keras模型,tensorflow Pb模型

13 篇文章 3 订阅
7 篇文章 0 订阅

这篇文章主要给初次想使用java 调用keras或者使用tensorflow训练好的模型,下面将详细介绍环境安装以及调用步骤。

假设你忘记了如何配置java环境,可以从第一部分看起,否则直接转入第二部分。

一、首先配置java环境

1、下载JDK1.8版本,并且安装

2、配置环境变量,本人安装路径为默认路径C:\ProgramFiles\Java\jdk1.8.0_131

3、在path里输入%HOME_JAVA%\bin;

4、测试,在cmd中输入java,javac


##############################################################################

##############################################################################

二、配置tensorflow for Java

这有两种方法,第一种在命令行中输入

1首先下载libtensorflow-1.4.0

https://www.tensorflow.org/versions/r1.4/install/install_java

严格按照官网安装方法,写好测试程序后,在cmd窗口中输入

1、  加载jar包

javac -cp libtensorflow-1.4.0-rc1.jar HelloTF.java

2、 加载.dll文件,dll存放在当前文件夹下jni文件夹下

java -cp libtensorflow-1.4.0-rc1.jar;.-Djava.library.path=jni HelloTF


上面这种方法我并不是特别喜欢,下面介绍在myeclipse中配置
##################################################

1、  导入jar包

2、  配置.dll文件

右击项目名称,按照下面截图方式,填入相应dll文件路径


此时环境已经配置完成,接下来第三部分将介绍java调用keras和tensorflow中的pb模型

三、java调用pb模型

如果你是使用keras训练的模型,首先要将 Keras模型转化成TensorFlow的pb模型,这个步骤很简单,只要获取输出节点就好,当然网上也有代码,也有公开的代码,代码如下:

#
#keras模型转化为tensorflow模型

# coding: utf-8

# In[ ]:


# Parse input arguments

# In[ ]:

import argparse
parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store", 
                    dest='input_fld', type=str, default='.')
parser.add_argument('-output_fld', action="store", 
                    dest='output_fld', type=str, default='')
parser.add_argument('-input_model_file', action="store", 
                    dest='input_model_file', type=str, default='cnn_model_new.h5')
parser.add_argument('-output_model_file', action="store", 
                    dest='output_model_file', type=str, default='')
parser.add_argument('-output_graphdef_file', action="store", 
                    dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store", 
                    dest='num_outputs', type=int, default=1)
parser.add_argument('-graph_def', action="store", 
                    dest='graph_def', type=bool, default=False)
parser.add_argument('-output_node_prefix', action="store", 
                    dest='output_node_prefix', type=str, default='output_node')
parser.add_argument('-quantize', action="store", 
                    dest='quantize', type=bool, default=False)

parser.add_argument('-f')
args = parser.parse_args()
parser.print_help()
print('input args: ', args)
# initialize

# In[ ]:

from keras.models import load_model
import tensorflow as tf
from pathlib import Path
from keras import backend as K

output_fld =  args.input_fld if args.output_fld == '' else args.output_fld
if args.output_model_file == '':
    args.output_model_file = str(Path(args.input_model_file).name) + '.pb'
Path(output_fld).mkdir(parents=True, exist_ok=True)    
weight_file_path = str(Path(args.input_fld) / args.input_model_file)


# Load keras model and rename output

# In[ ]:

try:
    net_model = load_model(weight_file_path)
except ValueError as err:
    print('''Input file specified ({}) only holds the weights, and not the model defenition.
    Save the model using mode.save(filename.h5) which will contain the network architecture
    as well as its weights. 
    If the model is saved using model.save_weights(filename.h5), the model architecture is 
    expected to be saved separately in a json format and loaded prior to loading the weights.
    Check the keras documentation for more details (https://keras.io/getting-started/faq/)'''
          .format(weight_file_path))
    raise err
num_output = args.num_outputs
pred = [None]*num_output
pred_node_names = [None]*num_output
for i in range(num_output):
    pred_node_names[i] = args.output_node_prefix+str(i)
    pred[i] = tf.identity(net_model.outputs[i], name=pred_node_names[i])
print('output nodes names are: ', pred_node_names)


# [optional] write graph definition in ascii

# In[ ]:

sess = K.get_session()

if args.graph_def:
    f = args.output_graphdef_file 
    tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
    print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f))


# convert variables to constants and save

# In[ ]:

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
if args.quantize:
    from tensorflow.tools.graph_transforms import TransformGraph
    transforms = ["quantize_weights", "quantize_nodes"]
    transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
    constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
else:
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)    
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the freezed graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))
此时可以使用java代码调用pb模型,首先要保证输入的张量和你模型训练时的一样,否则预测效果肯定很差,下面是我写的文本分类程序,用java调用模型,预测每句话的类别,代码放在我的github上。https://github.com/lplping/tensorflow-for-java

  • 4
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值