-
基础
TensorFlow 基础
TensorFlow 模型建立与训练
基础示例:多层感知机(MLP)
卷积神经网络(CNN)
循环神经网络(RNN)
深度强化学习(DRL)
Keras Pipeline
自定义层、损失函数和评估指标
常用模块 tf.train.Checkpoint :变量的保存与恢复
常用模块 TensorBoard:训练过程可视化
常用模块 tf.data :数据集的构建与预处理
常用模块 TFRecord :TensorFlow 数据集存储格式
常用模块 tf.function :图执行模式
常用模块 tf.TensorArray :TensorFlow 动态数组
常用模块 tf.config:GPU 的使用与分配 -
大规模训练与加速
TensorFlow 分布式训练
使用 TPU 训练 TensorFlow 模型 -
附录
强化学习基础简介
目录
Checkpoint 只保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的时候恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型),请参考 “部署” 章节中的 SavedModel 。
tf.train.Checkpoint
很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。可能你第一个想到的是用 Python 的序列化模块 pickle
存储 model.variables
。但不幸的是,TensorFlow 的变量类型 ResourceVariable
并不能被序列化。
好在 TensorFlow 提供了 tf.train.Checkpoint
这一强大的变量保存与恢复类,可以使用其 save()
和 restore()
方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言,tf.keras.optimizer
、 tf.Variable
、 tf.keras.Layer
或者 tf.keras.Model
实例都可以被保存。其使用方法非常简单,我们首先声明一个 Checkpoint:
checkpoint = tf.train.Checkpoint(model=model)
这里 tf.train.Checkpoint()
接受的初始化参数比较特殊,是一个 **kwargs
。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 tf.keras.Model
的模型实例 model
和一个继承 tf.train.Optimizer
的优化器 optimizer
,我们可以这样写:
checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)
这里 myAwesomeModel
是我们为待保存的模型 model
所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。
保存参数
接下来,当模型训练完成需要保存的时候,使用:
checkpoint.save(save_path_with_prefix)
就可以。 save_path_with_prefix
是保存文件的目录 + 前缀。
- 例如,在源代码目录建立一个名为 save 的文件夹并调用一次
checkpoint.save('./save/model.ckpt')
,我们就可以在 save 目录下发现名为checkpoint
、model.ckpt-1.index
、model.ckpt-1.data-00000-of-00001
的三个文件,这些文件就记录了变量信息。checkpoint.save()
方法可以运行多次,每运行一次都会得到一个 .index 文件和 .data 文件,序号依次累加。
载入之前保存的参数
当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:
model_to_be_restored = MyModel() # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) # 键名保持为“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)
即可恢复模型变量。 save_path_with_prefix_and_index
是之前保存的文件的目录 + 前缀 + 编号。
- 例如,调用
checkpoint.restore('./save/model.ckpt-1')
就可以载入前缀为 model.ckpt ,序号为 1 的文件来恢复模型。
当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path)
这个辅助函数返回目录下最近一次 checkpoint 的文件名。
- 例如如果 save 目录下有
model.ckpt-1.index
到model.ckpt-10.index
的 10 个保存文件,tf.train.latest_checkpoint('./save')
即返回./save/model.ckpt-10
。
保存变量+恢复变量
总体而言,恢复与保存变量的典型代码框架如下:
# train.py 模型训练阶段
model = MyMod