pb是protocol(协议) buffer(缓冲)的缩写。TensorFlow训练模型后存成的pb文件,是一种表示模型(神经网络)结构的二进制文件,将图中的变量保存成为常量,便于调用,一般无法将pb文件映射成源代码。pb文件的能够保存tensorflow计算图中的操作节点以及对应的各张量,方便我们日后直接调用之前已经训练好的计算图。
注:pb文件可以在训练时直接保存,也可以用.ckpt文件转化为.pb文件。
准备所需的文件:
1、训练完成的.index,.data,.mate文件。
2、编写.ckpt文件转.pb文件的代码。
output_node_names = ["Input/X_placeholder", "Inference/output"]#指定输入输出节点名
def freeze_pb(pb_file,ckpt_path):
with tf.name_scope('Input'):
input_data = tf.placeholder(dtype=tf.float32,shape=[None, 784], name='X_placeholder')
with tf.name_scope('Inference'):
# batch:20 输入:784,通道:1,输出:10
W = tf.Variable(initial_value=tf.random_normal(shape=[784,10], stddev=0.01), name='Weights')
b = tf.Variable(initial_value=tf.zeros(shape=[10]), name='bias')
print(W)
logits = tf.matmul(input_data, W) + b
pred = tf.nn.softmax(logits=logits,name='output')
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
saver = tf.train.Saver()
saver.restore(sess, ckpt_path)
converted_graph_def = tf.graph_util.convert_variables_to_constants(sess,
input_graph_def=sess.graph.as_graph_def(),
output_node_names=output_node_names)
with tf.gfile.GFile(pb_file, "wb") as f:
f.write(converted_graph_def.SerializeToString())
注:
1、在将ckpt文件转为pb文件的时候,一定要将输入节点名称和输出节点名称与设计的网络中的输入节点名称和输出节点名称对应起来,否则会报错。例如这里将
Input/X_placeholder改为Input/X_placeholder_e时会有如下的报错。
AssertionError: Input/X_placeholdere_e is not in graph.
output_node_names = ["Input/X_placeholder", "Inference/output"]指定输入输出。
2、tf.graph_util.convert_variables_to_constants中的output_node_names变量来指定保存的节点名称而不是张量的名称,“Input/X_placeholder:0”是张量的名称而"Input/X_placeholder"表示的是节点的名称。在固化pb文件的时候用节点名称,在调用pb文件是使用张量名称。
执行上述代码后会在指定的目录下生成pb文件。
3、编写调用pb文件代码。
return_enement = ["Input/X_placeholder:0", "Inference/output:0"]#指定输入输出张量名称
def Load_PbFile(pb_file,image):
imagedata=np.array(image).reshape([1,784])
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.gfile.FastGFile(pb_file, 'rb') as f:
frozen_graph_def = tf.GraphDef()
frozen_graph_def.ParseFromString(f.read())
elements = tf.import_graph_def(frozen_graph_def,return_elements=return_enement)
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
prob=sess.run([elements[1]],feed_dict={elements[0]:imagedata})
tolist=list(prob[0][0])
classes = tolist.index(max(tolist))
print(classes)
这里return_enement = ["Input/X_placeholder:0", "Inference/output:0"]指定的是张量名称,区别于output_node_names = ["Input/X_placeholder", "Inference/output"],否则在运行时会出现TypeError: Cannot interpret feed_dict key as Tensor: Can not convert a Operation into a Tensor.的错误。