Tensorflow2.0 继续训练自己未训练完的模型(tf.train.Checkpoint())

在我们使用tensorflow做深度学习的时候,需要用大量的数据来训练模型。但正因为数据量大如果电脑的性能不是很好的话在训练模型的时候我们的电脑是没有剩余的内存供我们使用的,但模型训练又需要花费很多时间,如果我们需要用电脑做其他事情的话就必须停止训练模型,但停止以后再重新开始从头训练的话又会花费很多的时间,所以我们要在停止训练时保存的模型参数的那个阶段继续我们的训练。

模型保存

首先我们要知道要想继续我们的训练就必须保存好我们之前训练好的模型参数,这样我们的程序才能使用现有的参数继续来训练模型而不是再随机生成参数那种大范围的拟合。
这里我保存模型的方法是用的 tf.train.Checkpoint()这个函数。

checkpoint_path = './checkpoint/train'
ckpt = tf.train.Checkpoint(transformer=transformer,optimizer=optimizer)
# ckpt管理器
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=3)

tf.train.Checkpoint()这个函数里面有两个参数一个是你要训练的模型,第二个是你的优化器。tf.train.CheckpointManager()函数里面有三个参数第一个是tf.train.Checkpoint()设定好你要保存的参数,第二个参数是你保存的路径,第三个参数是你要保存的模型数量。

当我从头开始训练的时候,这是训练后保存的模型参数和训练的准确度,稍后我们会用这个参数模型来演示中断后再开始训练的结果。
在这里插入图片描述

在这里插入图片描述

重新加载模型

checkpoint = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path1))

这两个函数就是重新加载我们的模型,需要注意的是tf.train.Checkpoint()里面的两个参数名要和你保存模型时的名字一样才行。

我们用上次保存下来的模型接着训练结果如下:
在这里插入图片描述
第一次从头开始训练模型时准确度为0.22,我们训练三个循环后保存模型时准确度是0.76。
当我们重新训练加载已有的模型时初次训练的准确度是0.78。正好接着上次保存后的准确度。

其实从头开始训练和接着上次的继续训练就只加了这两行代码。让模型迭代第一次时有初始的参数可以用。

checkpoint = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path1))
  • 11
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

辰溪0502

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值