(1)将保存的模型文件解析为GraphDef :graph_def.ParseFromString(gfile.FastGFile(“model.pb”,‘rb’).read()),这里分为两步,一:通过gfile.FastGFile(“model.pb”,‘rb’).read()获取保存的pb模型对象并读取文件,二: graph_def = tf.GraphDef()创建graph_def对象,然后 graph_def.ParseFromString()解析为二进制放进graph_def对象中。
(2)导入我们上一步创建的图为默认图,这里我们需要指定张量的名称而不是节点的名称: 1. tf.import_graph_def(graph_def,return_elements=[“add:0”]) , 2. tf.get_default_graph() 获取默认图 。此处"add:0"为张量名称。
(3)现在我们就可以通过get_tensor_by_name方法来获取tensor,执行我们的pb模型了: sess.graph.get_tensor_by_name(“input:0”)
简言之:
1.读pb文件,放进图中
2.获取图
3.get_tensor获取节点
pb_predict_test.py文件的两段代码
此文件主要是对train_and_save_model.py保存的模型的调用
1.调用pb代码
def recognize(jpg_path, pb_file_path):
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_file_path, "rb") as f:#主要步骤即为以下标出的几步,1、2步即为读取图
output_graph_def.ParseFromString(f.read())# 1.将模型文件解析为二进制放进graph_def对象
_ = tf.import_graph_def(output_graph_def, name="")# 2.import到当前图
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
graph = tf.get_default_graph()# 3.获得当前图
# 4.get_tensor_by_name获取需要的节点
x = graph.get_tensor_by_name("input_placeholder/x:0")
keep_prob = graph.get_tensor_by_name("drop_out/drop_out_placeholder:0")
y_out = graph.get_tensor_by_name("fc2/y_out:0")
img = io.imread(jpg_path)
img = transform.resize(img, (28, 28, 1))
#执行
test_y_out = sess.run(y_out, feed_dict={
x:np.reshape(img, [-1,784]),keep_prob: 1.0})
# print("test_y_out:{}".format(test_y_out))
prediction_labels = np.argmax(test_y_out, axis=1)
print("prediction_labels:{}".format(prediction_labels))
recognize("C:/Users/yangsunao/Pictures/Camera Roll/123.jpg", "graph1.pb")
2.调用训练好的ckpt代码
def ckpt(jpg_path):
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model\\' + 'model.meta')
saver.restore(sess, tf.train.latest_checkpoint('model\\'))
graph = tf.get_default_graph()
# one operation possibly have many outputs, so you need specify the which output, such as "name:0"
x = graph.get_tensor_by_name("input_placeholder/x:0")
keep_prob = graph.get_tensor_by_name("drop_out/drop_out_placeholder:0")
y_out = graph.get_tensor_by_name("fc2/y_out:0")
img = io.imread(jpg_path)
img = transform.resize(img, (28, 28, 1))
feed_dict={
x:np.reshape(img, [-1,784]),keep_prob: 1.0}
test_y_out = sess.run(y_out, feed_dict=feed_dict)
prediction_labels = np.argmax(test_y_out, axis=1)
print("CKPT_prediction_labels:{}".format(prediction_labels))
ckpt("C:/Users/yangsunao/Pictures/Camera Roll/123.jpg")
train_and_save_model.py
此段代码包括了模型的定义、训练、测试和保存为ckpt模型,以及把ckpt转为pb的代码
#!/usr/bin/env python
# -*- coding: utf