网络保存
- pb 文件:pb文件的网络结构和权重都是固化的不可更改
- meta等文件: 文件读入后可以再重新训练。
1. pb文件固化-python
为了实现用python训练网络,用C++直接运行网络,因此需要将网络固化成pb文件输出。
我的pb固化是在前向测试的时候实现的。代码如下
with tf.name_scope("Input"): #
input_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, depths],
name='input_image') # 由于图像存储的原因,灰度图维度较少一维
# is_training = tf.placeholder(tf.bool, name='is_training')
with tf.name_scope("net"):
out = net.inference(inputs=input_images, num_classes=labels_nums, dropout_keep_prob=1.0,
is_training=False)
# 将输出结果进行softmax分布,再求最大概率所属类别
with tf.name_scope('Output'):
score = tf.nn.softmax(out,name='predict')
class_id = tf.argmax(score, 1)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, models_path)
images_list=glob.glob(os.path.join(image_dir,'*.bmp'))
for i in range(len(images_list)):
processed_img = read_image(images_list[i], resize_height, resize_width, normalization=True)
input_imge = sess.run(processed_img)
input_imge = input_imge[np.newaxis, :]
pre_score,pre_label = sess.run([score,class_id], feed_dict={input_images:input_imge})
max_score=pre_score[0,pre_label]
file_logger.info("![]({0}){1} pre labels:{2} score: {3}{4}".format(images_list[i]," \r",labels[pre_label], max_score," \r"))
if is_savePb:
##TODO
##存成pbfile, 注意必须为List才能存
outPut_nodeName = []
outPut_nodeName.append('Output/predict') #网络输出的节点
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes
output_node_names=outPut_nodeName # The output node names are used to select the usefull nodes
)
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
sess.close()
需要注意的点是:
- input 和output要定义好,且input要用placeholder
- output节点的节点名前面要带上图名,而且outPut_nodeName的类型必须为List,即使只有一个节点
- 要记住输入输出的节点名,恢复的时候也是用同样的名字
2. pb文件恢复-python版本
代码如下:
def Predict(image_path, PbFileName, resize_height,resize_width):
"""
测试使用pbFile,单次读取一张图片并输出结果,
注意要给定输入节点的名称,用sess.graph.get_tensor_by_name("input:0"),节点格式为"nodeName:index"
:param image_path: 图像路径
:param PbFileName: pb图像路径
:param resize_height: 图像高度
:param resize_width: 图像宽度
:return:
"""
#读取pbFile
labels_filename = 'dataset/label.txt'
output_graph_def = tf.GraphDef()
labels = np.loadtxt(labels_filename, str, delimiter='\t')
with open(PbFileName, "rb") as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name="")
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
sess.run(init_op)
#TODO REDEFINE THE INPUT TENSOR IF NESSASARY
# 定义输入的张量名称,对应网络结构的输入张量
# input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
input_image_tensor = sess.graph.get_tensor_by_name("Input/input_image:0")
output_tensor_name = sess.graph.get_tensor_by_name("Output/predict:0")
input_imge = read_image(image_path, resize_height, resize_width, normalization=True)
input_imge = sess.run(input_imge)
input_imge = input_imge[np.newaxis, :]
score = sess.run(output_tensor_name, feed_dict={input_image_tensor: input_imge
})
print("score:{}".format(score))
class_id = tf.argmax(score, 1)
print( "pre class is :{}".format(labels[sess.run(class_id)]))
需要注意的点是
- 使用sess.graph.get_tensor_by_name(“Input/input_image:0”)的时候除了节点名还要加上":0",使用C++时候不用
3. pb文件导入-C++版本
//图的导入
Status LoadGraph(const string& graph_file_name,
std::unique_ptr<tensorflow::Session>* session) {
tensorflow::GraphDef graph_def;
Status load_graph_status =
ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
if (!load_graph_status.ok()) {
return tensorflow::errors::NotFound("Failed to load compute graph at '",
graph_file_name, "'");
}
session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
Status session_create_status = (*session)->Create(graph_def);
if (!session_create_status.ok()) {
return session_create_status;
}
return Status::OK();
}
//定义session
std::unique_ptr<tensorflow::Session> session;
string graph_path = tensorflow::io::JoinPath(root_dir, graph);
Status load_graph_status = LoadGraph(graph_path, &session); //输入图
if (!load_graph_status.ok()) {
LOG(ERROR) << load_graph_status;
return -1;
}
//定义输入输出
std::vector<Tensor> resized_tensors;
string image_path = tensorflow::io::JoinPath(root_dir, image);
Status read_tensor_status =
ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
input_std, &resized_tensors);
if (!read_tensor_status.ok()) {
LOG(ERROR) << read_tensor_status;
return -1;
}
const Tensor& resized_tensor = resized_tensors[0];
// Actually run the image through the model.
std::vector<Tensor> outputs;
Status run_status = session->Run({{input_layer, resized_tensor}},
{output_layer}, {}, &outputs);