tensorflow模型保存后继续训练_TensorFlow 训练模型的保存&加载

什么是Tensorflow的模型

Tensorflow的模型主要包括神经网络的架构设计(或者称为计算图的设计)和已经训练好的网络参数。因此,Tensorflow模型包括的主要文件:

“.meta”:包含了计算图的结构

“.data”:包含了变量的值

“.index”:确认checkpoint

“checkpiont”:一个protocol buffer,包含了最近的一些checkpoints

存储一个Tensorflow的模型

当我们训练的神经网络模型的损失函数或者精度收敛时,我们需要把参数或者网络结构存储起来。如果我们想要存储整个网络结构和该网络的所有参数,我们需要创建一个tf.train.Saver()的实例。Tensorflow变量的作用域仅在Session内部。因此,我们必须在一个Session的内部存储有关的数据。

saver.save(sess,'my_test_model')

sess是我们创建的一个Session实例,my_test_model是我们给模型的命名。

具体的实例:

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')

w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')

saver = tf.train.Saver()

sess = tf.Session()

sess.run(tf.global_variables_initializer())

saver.save(sess, './my_test_model')

sess.close()

执行上述语句,我们会同级目录下看到新增的文件:

my_test_model.data-00000-of-00001

my_test_model.index

my_test_model.meta

如果网络架构更改了,Tensorflow会重写上述的文件。

如果我们想要每1000步保存一次,那么需要更改语句:

saver.save(sess, 'my_test_model', global_step=1000)

那么当训练时,我们会每1000次迭代存储一次模型。.meta会在第一次到达1000次迭代时创建,之后的每千步,就不需要在重新创建.meta文件了。只要图的架构 不更改,就不需要重新

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
TensorFlow模型训练中断时,可以通过以下步骤继续训练: 1. 保存模型参数。在模型训练时,可以使用tf.train.Saver()保存模型参数。可以将模型参数保存到一个文件中。 2. 加载模型参数。在重新开始训练时,可以使用tf.train.Saver()从文件中加载之前保存模型参数。 3. 继续训练模型。使用加载模型参数继续进行训练。可以使用之前使用的优化器和损失函数。 下面是一个简单的示例代码,展示了如何保存加载模型参数,并继续训练模型: ``` import tensorflow as tf # 定义模型 x = tf.placeholder(tf.float32, shape=[None, 784]) y = tf.placeholder(tf.float32, shape=[None, 10]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y_pred = tf.nn.softmax(tf.matmul(x, W) + b) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), reduction_indices=[1])) optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 训练模型 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver = tf.train.Saver() for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys}) if i % 100 == 0: saver.save(sess, './model.ckpt') saver.save(sess, './model.ckpt') # 加载模型参数并继续训练 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver = tf.train.Saver() saver.restore(sess, './model.ckpt') for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys}) if i % 100 == 0: saver.save(sess, './model.ckpt') ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值