因为目前keras比较容易上手,很多网络现在都是用keras做训练,得到的权值结果保存在h5文件中。如果想要在工程部署上使用离线调用,一般都是转成pb文件,尤其在C++调用时,tensorflow有专门为C++调用的API,非常好用。
1. 此文生成的keras权值是基于keras2.1.0,python3.6.5。
2. 保存权值时使用的代码是:model.save()函数。
3. 定义h5文件转pb文件的函数:在python下完成
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
graph = session.graph
with graph.as_default():
input_graph_def = graph.as_graph_def()
frozen_graph = graph_util.convert_variables_to_constants(session, input_graph_def, output_names)
return frozen_graph
4. 调用freeze_session()函数
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
input_fld = './h5_weight/'
weight_file = 'k_g_27_image_0_c1.h5'
output_graph_name = 'k_g_27_image_0_c1.pb'
output_fld = './pb_weight/'
if not os.path.isdir(output_fld):
os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)
K.set_learning_phase(0)
net_model = load_model(weight_file_path)
sess = K.get_session()
frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
在测试过程中 ,函数write_graph的参数as_text的值设置成true时,在后面读取pb文件时会出错。
5. 在python中调用转换的pb文件进行预测
import tensorflow as tf
from PIL import Image
import numpy as np
import os
def freeze_graph_test(pb_path, image_path, out_path):
'''
:param pb_path:pb文件的路径
:param image_data:测试图片的路径
:return:
'''
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name="")
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
print(tensor_name, '\n')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
input_image_tensor = sess.graph.get_tensor_by_name("input_1:0") # 定义输入的张量名称,对应网络结构的输入张量
output_tensor_name = sess.graph.get_tensor_by_name("conv2d_19/Sigmoid:0") # 定义输出的张量名称。
im = read_image(image_path) # 读取测试图片, 按照训练输入的数据格式准备
im=np.expand_dims(im, axis=3)
out = sess.run(output_tensor_name, feed_dict={input_image_tensor: im})
image_path = "./data_predict_c1/test/images/7.jpg"
pb_path = "./pb_weight/k_g_27_image_0_c1.pb"
out_path = "./data_predict_c1/predict/"
if not os.path.exists(out_path):
os.makedirs(out_path)
freeze_graph_test(pb_path=pb_path, image_path=image_path, out_path=out_path)
以上就是keras的h5权值文件转tensorflow的pb文件的整个过程。