tensorflow模型的保存与恢复

本文详细介绍了Tensorflow模型的保存与恢复过程。包括元数据图(.META文件)和检查点文件(checkpoint file)的作用,以及如何使用tf.train.Saver()进行模型保存和恢复。还讲解了如何指定保存的变量、查看模型参数、导入预训练模型并进行扩展,以及处理模型中的占位符数据。
摘要由CSDN通过智能技术生成

1. 什么是Tensorflow模型?

https://my.oschina.net/u/2272631/blog/1556094

训练神经网络后,需要将它保存以备将来使用和部署到生产环境。那么什么是tensorflow模型?  

Tensorflow模型主要包含NN的网络图和已经训练好的变量参数,因此,Tensorflow模型有两个主要的文件:

  1. 元数据图(meta graph

这是一个协议缓冲区,它保存了tensorflow完整的网络图结构,即所有变量、操作、集合等。这个文件以.META为扩展名。

  1. 检查点文件(checkpoint file

这是一个二进制文件,它包含了所有的权重变量,biases变量和其他变量的值。在0.11版本之后,包含三个文件

checkpoint

my_model.data-00000-of-00001

my_model.index

.data文件是包含我们训练变量的文件;与此同时,Tensorflow也有一个名为checkpoint的文件。它只是不断的保存最新的检查点文件的记录。

 

2. 保存Tensorflow模型

训练图像分类的CNN,作为一种标准的做法,在训练模型时,需要一直关注着模型的损失函数和模型的准确度。一旦发现网络已经收敛,就可以停止训练。训练完成后,希望将所有的变量和网络模型保存下来,供以后使用。在tensorflow中,使用tf.train.Saver() 来保存NN的网络结构图和相关变量。

saver = tf.train.Saver()

需要注意,tensorflow变量的作用范围是在一个session里面。在保存模型的时候,应该在session里面通过save方法保存。

saver.save(sess, 'my-test-model')

一个完整的模型保存:

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')

w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')

saver = tf.train.Saver()

sess = tf.Session()

sess.run(tf.global_variables_initializer())

saver.save(sess, 'saved_model/my_test_model')

其中,saved_model是保存模型的文件路径,my_test_model是模型的名称,文件保存结果示意图:

 

如果希望在迭代

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值