Tensorflow2学习记录
提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加
例如:第一章 Tensorflow2网络的保存与恢复后继续训练。
提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档
前言
最近计划深入学习一下机器学习,奈何笔记本不给力,训练起来超级慢。
又必须关机带着电脑走。晚上也不想把电脑设置为高性能不关机的模式,
那样太吵了。所以一直在找怎么将继续训练的代码写出来,这样我就
可以随时停止训练,然后背电脑回家或者可以用电脑做别的事情,没事了
就继续接着训练。这样虽然效率低一点,但这样才是常态啊。不是每个人
都有专门的机器学习服务器呀!
原来在matlab上用过机器学习,刚开始具体接触Python下的机器学习,
听说TensorFlow2比较好用,就从它入手。结果网络上铺天盖地的都是
TensorFlow1的帖子,好像找到点有关系的帖子,但是要订阅,每个月
几十上百块。都收钱了,博客还有啥意义?
提示:以下是本篇文章正文内容,下面案例可供参考
一、官方帮助在哪里?
最害怕看到最后还没有找到自己的答案。
所以下面直接给出我用的官方链接,您可以先看看,
如果看了之后不会,您老再向后看看我是怎么弄的。
官方帮助链接:https://tensorflow.google.cn/guide/checkpoint?hl=en
github实例链接:https://github.com/tensorflow/docs/blob/master/site/en/guide/checkpoint.ipynb
github上的链接可能不好下载,但偶尔还可以看。所以下面就主要根据github上的实例介绍恢复并重新训练的功能。
二、学习过程
首先,你应该准备好网络了。你可以随便找个网络比如网络上到处都有的MNIST手写数字识别的网络。
最开始我看的是简单粗暴TensofFlow2里面的一个例子。
链接地址:https://tf.wiki/zh_hans/basic/tools.html#tf-train-checkpoint
开始理解checkpoint
我们这里的目标是临时保存网络,等有空了再把网络恢复,然后重新接着训练,这里的接着是最重要的。
它给出的例子如下
import tensorflow as tf
import numpy as np
import argparse
from zh.model.mnist.mlp import MLP
from zh.model.utils import MNISTLoader
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--mode', default='train', help='train or test')
parser.add_argument('--num_epochs', default=1)
parser.add_argument('--batch_size', default=50)
parser.add_argument('--learning_rate', default=0.001)
args = parser.parse_args()
data_loader = MNISTLoader()
def train():
model = MLP()
optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate