0、将训练好的神经网络导出为 pb 文件,并编写测试程序
以下是程序的关键代码,详细内容见链接
# 从训练好的ckpt中,导出pb文件
import fully_conected as model
import tensorflow as tf
def export_graph(model_name):
    """Freeze the newest checkpoint in ``log`` into a binary GraphDef file.

    The inference graph is rebuilt with a float32 placeholder named
    ``inputdata`` (shape ``[None, 28*28]``) and a softmax op named
    ``outputdata``. The latest checkpoint found in ``log`` is restored, the
    variables are converted to constants, and the serialized graph is
    written to ``log/<model_name>``.

    Args:
        model_name: File name of the exported graph inside ``log``,
            e.g. ``'mnist.pb'``.

    Raises:
        ValueError: If no checkpoint can be found in ``log``.
    """
    import os  # local import so the snippet stays self-contained

    log_dir = 'log'
    graph = tf.Graph()
    with graph.as_default():
        input_image = tf.placeholder(
            tf.float32, shape=[None, 28 * 28], name='inputdata')
        # The network has to be rebuilt here before the saver is created.
        logits = model.inference(input_image)
        # Only the op name matters: it is looked up by name when freezing.
        tf.nn.softmax(logits, name='outputdata')
        restore_saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        latest_ckpt = tf.train.latest_checkpoint(log_dir)
        if latest_ckpt is None:
            # Fail with a clear message instead of passing None to restore().
            raise ValueError('no checkpoint found in %r' % log_dir)
        restore_saver.restore(sess, latest_ckpt)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ['outputdata'])
        # Honour the requested file name instead of a hard-coded one.
        output_path = os.path.join(log_dir, model_name)
        with tf.gfile.GFile(output_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())


export_graph('mnist.pb')
# 测试调用保存的pb 文件
from __future__ import absolute_import, unicode_literals
from datasets_mnist import read_data_sets
import tensorflow as tf
# Load the dataset splits; only the test split is evaluated below.
train, validation, test = read_data_sets("datasets/", one_hot=True)

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    output_graph_path = 'log/mnist.pb'
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    # name="" keeps the node names as saved, so the tensors can be looked up
    # as "inputdata:0" / "outputdata:0" below.
    tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        # No variable initialisation here: the deprecated
        # tf.initialize_all_variables() call was dropped because the imported
        # graph is expected to be frozen (constants only) -- confirm if the
        # .pb comes from a different exporter.
        input_x = sess.graph.get_tensor_by_name("inputdata:0")
        output = sess.graph.get_tensor_by_name("outputdata:0")
        y_conv_2 = sess.run(output, {input_x: test.images})
        print("y_conv_2", y_conv_2)

        # Test trained model: compare the predicted class with the label
        # class (argmax over axis 1 of both arrays).
        y__2 = test.labels
        correct_prediction_2 = tf.equal(
            tf.argmax(y_conv_2, 1), tf.argmax(y__2, 1))
        print("correct_prediction_2", correct_prediction_2)
        accuracy_2 = tf.reduce_mean(tf.cast(correct_prediction_2, "float"))
        print("accuracy_2", accuracy_2)
        # eval() uses the session opened by the enclosing `with`.
        print("check accuracy %g" % accuracy_2.eval())
1、使用 tf.get_collection 获取可训练变量
# Alternative approach, kept for reference: filter the full list by name.
#   train_vars = tf.trainable_variables()
#   g_vars = [var for var in train_vars if var.name.startswith('generator')]
#   d_vars = [var for var in train_vars if var.name.startswith('discriminator')]
# Here the TRAINABLE_VARIABLES collection is queried once per scope instead.
g_vars = tf.get_collection(
    key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
d_vars = tf.get_collection(
    key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
2、 tf.identity()
import tensorflow as tf
# Demo: fetch a tf.identity op created under a control dependency.
x = tf.Variable(1.0)
x_plus_1 = tf.assign_add(x, 1)

with tf.control_dependencies([x_plus_1]):
    # Plain Python alias of the variable; not fetched in this demo.
    y = x
    # tf.identity creates a new op inside the control-dependency scope.
    z = tf.identity(x, name='x')

init_op = tf.global_variables_initializer()

with tf.Session() as session:
    session.run(init_op)
    # Fetch z five times and print each value.
    for _ in range(5):
        print(session.run(z))
输出是:2,3,4,5,6(tf.identity 会在图中创建一个新的操作,该操作受 control_dependencies 约束,因此每次求值 z 之前都会先执行一次 assign_add)
import tensorflow as tf
# Demo: fetch a plain Python alias of the variable instead of an identity op.
x = tf.Variable(1.0)
x_plus_1 = tf.assign_add(x, 1)

with tf.control_dependencies([x_plus_1]):
    # `y = x` only binds another Python name to the same variable object.
    y = x
    # Identity op created for comparison; not fetched in this demo.
    z = tf.identity(x, name='x')

init_op = tf.global_variables_initializer()

with tf.Session() as session:
    session.run(init_op)
    # Fetch y five times and print each value.
    for _ in range(5):
        print(session.run(y))
输出是:1,1,1,1,1(y = x 只是 Python 层面的赋值,并没有在图中创建新的操作,control_dependencies 对它不起作用,因此 assign_add 不会被执行,x 的值保持不变)