tensorflow2.4 用checkpoint保存网络,读取后继续训练!

Tensorflow2学习记录

提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加
例如:第一章 Tensorflow2网络的保存与恢复后继续训练。


提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

最近计划深入学习一下机器学习,奈何笔记本不给力,训练起来超级慢。
又必须关机带着电脑走。晚上也不想把电脑设置为高性能不关机的模式,
那样太吵了。所以一直在找怎么将继续训练的代码写出来,这样我就
可以随时停止训练,然后背电脑回家或者可以用电脑做别的事情,没事了
就继续接着训练。这样虽然效率低一点,但这样才是常态啊。不是每个人
都有专门的机器学习服务器呀!

原来在matlab上用过机器学习,刚开始具体接触Python下的机器学习,
听说TensorFlow2比较好用,就从它入手。结果网络上铺天盖地的都是
TensorFlow1的帖子,好像找到点有关系的帖子,但是要订阅,每个月
几十上百块。都收钱了,博客还有啥意义?


提示:以下是本篇文章正文内容,下面案例可供参考

一、官方帮助在哪里?

最害怕看到最后还没有找到自己的答案。
所以下面直接给出我用的官方链接,您可以先看看,
如果看了之后不会,您老再向后看看我是怎么弄的。
官方帮助链接:https://tensorflow.google.cn/guide/checkpoint?hl=en
github实例链接:https://github.com/tensorflow/docs/blob/master/site/en/guide/checkpoint.ipynb

github上的链接可能不好下载,但偶尔还可以看。所以下面就主要根据github上的实例介绍恢复并重新训练的功能。

二、学习过程

首先,你应该准备好网络了。你可以随便找个网络比如网络上到处都有的MNIST手写数字识别的网络。
最开始我看的是简单粗暴TensofFlow2里面的一个例子。
链接地址:https://tf.wiki/zh_hans/basic/tools.html#tf-train-checkpoint

开始理解checkpoint

我们这里的目标是临时保存网络,等有空了再把网络恢复,然后重新接着训练,这里的接着是最重要的。

它给出的例子如下

import tensorflow as tf
import numpy as np
import argparse
from zh.model.mnist.mlp import MLP
from zh.model.utils import MNISTLoader

parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--mode', default='train', help='train or test')
parser.add_argument('--num_epochs', default=1)
parser.add_argument('--batch_size', default=50)
parser.add_argument('--learning_rate', default=0.001)
args = parser.parse_args()
data_loader = MNISTLoader()

def train():
    model = MLP()
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate
  • 5
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值