使用tensorflow过程中,训练结束后我们需要用到模型文件。有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练。这时候我们需要掌握如何操作这些模型数据。看完本文,相信你一定会有收获!
1 Tensorflow模型文件
我们在checkpoint_dir目录下保存的文件结构如下:
|--checkpoint_dir
| |--checkpoint
| |--MyModel.meta
| |--MyModel.data-00000-of-00001
| |--MyModel.index
1.1 meta文件
MyModel.meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量、op、集合等。
1.2 ckpt文件
ckpt文件是二进制文件,保存了所有的weights、biases、gradients等变量。在tensorflow 0.11之前,保存在.ckpt文件中。0.11后,通过两个文件保存,如:
MyModel.data-00000-of-00001
MyModel.index
- 1
- 2
1.3 checkpoint文件
我们还可以看,checkpoint_dir目录下还有checkpoint文件,该文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model
2 保存Tensorflow模型
tensorflow 提供了tf.train.Saver
类来保存模型,值得注意的是,在tensorflow中,变量是存在于Session环境中,也就是说,只有在Session环境下才会存有变量值,因此,保存模型时需要传入session:
saver = tf.train.Saver()
saver.save(sess,"./checkpoint_dir/MyModel")
看一个简单例子:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess =