tensorflow 恢复模型并测试-2种方法

Tensorflow 保存模型后恢复模型测试-2种方法:

1. 保存模型参数,恢复时,先重建网络,再倒入参数。 如下以mnist为例子eg:

import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.examples.tutorials.mnist import input_data
#下载mnist如果没有  则下载
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)

def model(inputs):
    conv1=slim.conv2d(inputs,32,3)
    pool2=slim.max_pool2d(conv1,2,2,'SAME')
    conv2=slim.conv2d(pool2,64,3)
    pool2=slim.max_pool2d(conv2,2,2,'SAME')
    print(pool2.get_shape().as_list())
    re=tf.reshape(pool2,[-1,7*7*64])
    output=slim.fully_connected(re,10,None)
    return output
def train():
    xs=tf.placeholder(tf.float32,(None,784))# 28*28
    ys=tf.placeholder(tf.float32,(None,10))
    x_image=tf.reshape(xs,[-1,28,28,1])
    output=model(x_image)
    tf.add_to_collection('output',output)#####赋别名######
    cross_entropy=tf.reduce_mean(slim.nn.softmax_cross_entropy_with_logits(labels=ys,logits=output))
    train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    #交互器
    sess=tf.Session()
    #初始化 - 必须
    sess.run(tf.global_variables_initializer())
    saver=tf.train.Saver(max_to_keep=1)
  
    for i in range(0,300):
        batch_xs,batch_ys=mnist.train.next_batch(100)
        feed_dict={xs: batch_xs,ys: batch_ys}
        sess.run(train_step, feed_dict=feed_dict)
        if i%50 == 0:
            print(sess.run(cross_entropy,feed_dict=feed_dict))
            saver.save(sess,save_path=r'/home/**/model/model.ckpt',global_step=i)
    sess.close()

def test():
    xs=tf.placeholder(tf.float32,(None,784))# 28*28  ##这里的占位符变量必须是唯一的,不能被后面的变量覆盖掉。。。
    ys=tf.placeholder(tf.float32,(None,10))
    x_image=tf.reshape(xs,[-1,28,28,1])
    
    output=model(x_image)
    sess=tf.Session()
    saver=tf.train.Saver()
    saver.restore(sess,r'/home/**/model/model.ckpt-250') # should add number after model
    
    correct_prediction=tf.equal(tf.argmax(output,1),tf.argmax(ys,1))
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    result=sess.run(accuracy,feed_dict={xs:mnist.test.images,ys:mnist.test.labels})
    print(result)
    sess.close()
    

def main():
    trains=False # True训练,False测试。注意:测试时,需要关掉Ipython,自动重开。才可以成功!
    if trains:
        train()
    else:
        test()

if __name__=='__main__':
    main()

 

2. 保存整体,先倒入保存的网络结构,再倒入参数

import tensorflow as tf

sess=tf.Session()

#import meta 结构保存 graph
new_saver=tf.train.import_meta_graph(r'/home/**/model/model.ckpt-250.meta')
#import weights bias
new_saver.restore(sess,save_path=r'/home/**/model/model.ckpt-250')

#obtain we need
output=tf.get_collection('output')[0] # 需要训练时,先tf.add_to_collection('output',output)#

graph=tf.get_default_graph() #导入默认图,即:前面的网络结构

#placeholder
image=graph.get_operation_by_name('image').outputs[0]
label=graph.get_operation_by_name('label').outputs[0]

correct_prediction=tf.equal(tf.argmax(output,1),tf.argmax(label,1))

accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

print(sess.run(accuracy,feed_dict={image:mnist.test.images,label:mnist.test.labels}))

 

以上即为tensorflow的2种恢复模型方式!!!

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值