tensorflow保存和读取模型(通过图.meta)

由于我每过一段时间,去写模型的时候,就会忘记怎么保存和读取模型。于是,我写下这篇博客以用于自己做笔记。如果对大家有所帮助,那就感谢大家赏脸。如果哪里不足,还请大家评论告知。以弥补我自己的不足。

保存模型

保存模型很简单
两行代码就可以解决问题

x=tf.placeholder(tf.float32,[None,28,28,1],name='x')
y=tf.placeholder(tf.int64,[None],name='y')

# print(train_y)
y_=resnet(x,16,[3,3,3,3],10)
loss=tf.losses.sparse_softmax_cross_entropy(y,y_)

with tf.name_scope('train_op'):
    train=tf.train.AdamOptimizer().minimize(loss)

real=tf.argmax(y_,1)
corrent=tf.equal(real,y)
acc=tf.reduce_mean(tf.cast(corrent,tf.float64))
init=tf.global_variables_initializer()
saver=tf.train.Saver()

with tf.Session() as sess:
	sess.run(init)
	saver.save(sess,'model.resnet.ckpt')

上面很多代码省略,主要观看两行代码

saver=tf.train.Saver()
saver.save(sess,'model.resnet.ckpt')

于是乎,模型就保存完毕,在model文件夹生成了四个文件
mnist模型图
上图中,.meta文件夹就是我们模型的图了!

读取模型

有了模型,我们现在就开始读取模型,以便于我们做预测或进行迁移学习

model_path='model/'
saver=tf.train.import_meta_graph(model_path+'model.resnet.ckpt.meta')
with tf.Session() as sess:
	saver.restore(sess,tf.train.latest_checkpoint(model_path))
	graph = tf.get_default_graph()
	x=graph.get_tensor_by_name('x:0')
	y=graph.get_tensor_by_name('y:0')
	y_=graph.get_tensor_by_name('y_/BiasAdd:0')
	Y=sess.run(y_,feed_dict={x:X})

这样读取模型的好处,可以不必重新建立图结构。
如果不知道有的tensor的name,可以用一个方法,在里面,你找到name就可以了

    for op in graph.get_operations():
        print(op)

里面的东西很多,你需要耐心去找。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值