本文介绍用于编写和读取检查点的 API
训练检查点指南
检查点可以捕获模型使用的所有参数(tf.Variable 对象)的确切值。检查点不包含对模型所定义计算的任何描述,因此通常仅在将使用保存参数值的源代码可用时才有用。
编写和读取检查点的 API。
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
tf.keras.Model.save_weights 可以保存一个 TensorFlow 检查点。
net.save_weights('easy_checkpoint')
编写检查点
TensorFlow 模型的持久状态存储在 tf.Variable 对象中。这些对象可以直接构造,但通常会通过像 tf.keras.layers 或 tf.keras.Model 这样的高级 API 创建。
管理变量的最简单方法是将它们附加到 Python 对象,然后引用这些对象。
tf.train.Checkpoint、tf.keras.layers.Layer 和 tf.keras.Model 的子类会自动跟踪分配给其特性的变量。下面的示例构造了一个简单的线性模型,然后编写检查点,其中包含该模型所有变量的值。
您可以使用 Model.save_weights 轻松保存模型检查点。
tf.train.Checkpoint
为了帮助演示 tf.train.Checkpoint 的所有功能, 下面定义了一个小数据集和优化步骤:
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict