TensorFlow 保存模型后恢复模型的两种方法:
1. 只保存模型参数:恢复时,先重建网络结构,再导入参数。下面以 MNIST 为例:
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.examples.tutorials.mnist import input_data
# Download the MNIST data set if it is not already present locally.
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
def model(inputs):
    """Build a small conv net: two conv/pool stages followed by a 10-way linear layer.

    Args:
        inputs: image batch tensor shaped (batch, 28, 28, 1).
    Returns:
        Logits tensor shaped (batch, 10) (no activation on the final layer).
    """
    net = slim.conv2d(inputs, 32, 3)
    net = slim.max_pool2d(net, 2, 2, 'SAME')
    net = slim.conv2d(net, 64, 3)
    net = slim.max_pool2d(net, 2, 2, 'SAME')
    # Two 2x2 poolings reduce 28x28 to 7x7 with 64 channels.
    print(net.get_shape().as_list())
    flattened = tf.reshape(net, [-1, 7 * 7 * 64])
    return slim.fully_connected(flattened, 10, None)
def train():
    """Train the MNIST conv net for 300 steps, checkpointing every 50 steps.

    Saves at most one checkpoint (max_to_keep=1) under /home/**/model/.
    """
    # Name the placeholders so the "restore whole graph" method can look them
    # up with graph.get_operation_by_name('image' / 'label') after
    # tf.train.import_meta_graph. (Unnamed placeholders get auto-generated
    # names like 'Placeholder', which that lookup would not find.)
    xs = tf.placeholder(tf.float32, (None, 784), name='image')  # flattened 28*28
    ys = tf.placeholder(tf.float32, (None, 10), name='label')   # one-hot labels
    x_image = tf.reshape(xs, [-1, 28, 28, 1])
    output = model(x_image)
    # Store the logits under an alias so restore code can fetch them without
    # rebuilding the network.
    tf.add_to_collection('output', output)
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=ys, logits=output))
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    saver = tf.train.Saver(max_to_keep=1)
    # `with` guarantees the session is closed even if training raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # mandatory before training
        for i in range(300):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            feed_dict = {xs: batch_xs, ys: batch_ys}
            sess.run(train_step, feed_dict=feed_dict)
            if i % 50 == 0:
                print(sess.run(cross_entropy, feed_dict=feed_dict))
                saver.save(sess, save_path=r'/home/**/model/model.ckpt',
                           global_step=i)
def test():
    """Rebuild the network, restore the saved weights, and print test accuracy."""
    # These placeholders must stay unique in this graph — do not shadow them
    # with later variables, or the restore/feed will break.
    images = tf.placeholder(tf.float32, (None, 784))  # flattened 28*28
    labels = tf.placeholder(tf.float32, (None, 10))
    logits = model(tf.reshape(images, [-1, 28, 28, 1]))
    sess = tf.Session()
    saver = tf.train.Saver()
    # Note: the step suffix (-250) must be appended to the checkpoint prefix.
    saver.restore(sess, r'/home/**/model/model.ckpt-250')
    hits = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))
    score = sess.run(accuracy,
                     feed_dict={images: mnist.test.images,
                                labels: mnist.test.labels})
    print(score)
    sess.close()
def main():
    """Entry point: flip `run_training` to choose between train and test."""
    # True -> train, False -> evaluate. When switching to testing, restart the
    # interpreter (e.g. close and reopen IPython) so the stale default graph
    # from training does not interfere.
    run_training = False
    if run_training:
        train()
    else:
        test()


if __name__ == '__main__':
    main()
2. 保存整体:恢复时,先导入保存的网络结构(.meta 文件),再导入参数
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Fix: the original snippet fed `mnist.test.*` below without ever defining
# `mnist` — load the evaluation data first.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

sess = tf.Session()
# Step 1: recreate the saved graph structure from the .meta file.
new_saver = tf.train.import_meta_graph(r'/home/**/model/model.ckpt-250.meta')
# Step 2: load the saved weights and biases into that graph.
new_saver.restore(sess, save_path=r'/home/**/model/model.ckpt-250')
# Fetch the logits tensor that training registered via
# tf.add_to_collection('output', output).
output = tf.get_collection('output')[0]
graph = tf.get_default_graph()  # the graph rebuilt by import_meta_graph
# Look up the input placeholders by the names given at training time
# (the placeholders must have been created with name='image' / name='label').
image = graph.get_operation_by_name('image').outputs[0]
label = graph.get_operation_by_name('label').outputs[0]
correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy,
               feed_dict={image: mnist.test.images, label: mnist.test.labels}))
以上即为 TensorFlow 的两种模型恢复方式。