什么是Tensorflow的模型
Tensorflow的模型主要包括神经网络的架构设计(或者称为计算图的设计)和已经训练好的网络参数。因此,Tensorflow模型包括的主要文件:
“.meta”:包含了计算图的结构
“.data”:包含了变量的值
“.index”:确认checkpoint
“checkpiont”:一个protocol buffer,包含了最近的一些checkpoints
存储一个Tensorflow的模型
当我们训练的神经网络模型的损失函数或者精度收敛时,我们需要把参数或者网络结构存储起来。如果我们想要存储整个网络结构和该网络的所有参数,我们需要创建一个tf.train.Saver()的实例。Tensorflow变量的作用域仅在Session内部。因此,我们必须在一个Session的内部存储有关的数据。
saver.save(sess,'my_test_model')
sess是我们创建的一个Session实例,my_test_model是我们给模型的命名。
具体的实例:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './my_test_model')
sess.close()
执行上述语句,我们会同级目录下看到新增的文件:
my_test_model.data-00000-of-00001
my_test_model.index
my_test_model.meta
如果网络架构更改了,Tensorflow会重写上述的文件。
如果我们想要每1000步保存一次,那么需要更改语句:
saver.save(sess, 'my_test_model', global_step=1000)
那么当训练时,我们会每1000次迭代存储一次模型。.meta会在第一次到达1000次迭代时创建,之后的每千步,就不需要在重新创建.meta文件了。只要图的架构 不更改,就不需要重新