Tensorflow2.0之模型权值的保存与恢复(Checkpoint)

介绍

很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。

TensorFlow 提供了 tf.train.Checkpoint 这一强大的变量保存与恢复类,可以使用其 save() 和 restore() 方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言,tf.keras.optimizer 、 tf.Variable 、 tf.keras.Layer 或者 tf.keras.Model 实例都可以被保存。

保存变量

# train.py 模型训练阶段

model = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model)
# ...(模型训练代码)
# 模型训练完毕后将参数保存到文件
checkpoint.save('./save/model.ckpt')

这里 tf.train.Checkpoint() 接受的初始化参数比较特殊,是一个 **kwargs 。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。在这里,我们取键名为 myModel,指定保存对象为 model。如果我们希望保存其他对象如 Optimizer 的参数,我们可以这样写:

checkpoint = tf.train.Checkpoint(myModel=model, myOptimizer=optimizer)

训练完后,checkpoint 文件会出现在 ‘./save/’ 文件夹下,‘model.ckpt’ 是这些文件的前缀。如果我们只调用了一次 checkpoint.save 函数,那么在 ‘./save/’ 文件夹下会出现名为 checkpoint 、 model.ckpt-1.index 、 model.ckpt-1.data-00000-of-00001 的三个文件,这些文件就记录了变量信息。checkpoint.save() 方法可以运行多次,每运行一次都会得到一个.index 文件和.data 文件,序号依次累加。

恢复变量

当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。

# test.py 模型使用阶段

model_to_be_restored = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model_to_be_restored)  # 实例化Checkpoint,指定恢复对象为model
checkpoint.restore(tf.train.latest_checkpoint('./save'))  # 从文件恢复模型参数

当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path) 这个辅助函数返回目录下最近一次 checkpoint 的文件名。例如如果 save 目录下有 model.ckpt-1.index 到 model.ckpt-10.index 的 10 个保存文件, tf.train.latest_checkpoint(’./save’) 即返回 ./save/model.ckpt-10 。

有限制地保留 Checkpoint 文件

在模型的训练过程中,我们往往每隔一定步数保存一个 Checkpoint 并进行编号。不过很多时候我们会有这样的需求:

  • 在长时间的训练后,程序会保存大量的 Checkpoint,但我们只想保留最后的几个 Checkpoint;

  • Checkpoint 默认从 1 开始编号,每次累加 1,但我们可能希望使用别的编号方式(例如使用当前 epoch 的编号作为文件编号)。

这时,我们可以使用 TensorFlow 的 tf.train.CheckpointManager 来实现以上需求。具体而言,在定义 Checkpoint 后接着定义一个 CheckpointManager:

checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)

在需要保存模型的时候,我们直接使用 manager.save() 即可。如果我们希望自行指定保存的 Checkpoint 的编号,则可以在保存时加入 checkpoint_number 参数。例如 manager.save(checkpoint_number=100) 。

实例

我们通过对 MNIST 数据集的训练来举例:

1、定义模型及训练过程

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers

mnist = keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis].astype(np.float32)
x_test = x_test[..., tf.newaxis].astype(np.float32)

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(x_test.shape[0])

class MyModel(keras.Model):
    # Set layers.
    def __init__(self):
        super(MyModel, self).__init__()
        # Convolution Layer with 32 filters and a kernel size of 5.
        self.conv1 = layers.Conv2D(32, kernel_size=5, activation=tf.nn.relu)
        # Max Pooling (down-sampling) with kernel size of 2 and strides of 2.
        self.maxpool1 = layers.MaxPool2D(2, strides=2)

        # Convolution Layer with 64 filters and a kernel size of 3.
        self.conv2 = layers.Conv2D(64, kernel_size=3, activation=tf.nn.relu)
        # Max Pooling (down-sampling) with kernel size of 2 and strides of 2.
        self.maxpool2 = layers.MaxPool2D(2, strides=2)

        # Flatten the data to a 1-D vector for the fully connected layer.
        self.flatten = layers.Flatten()

        # Fully connected layer.
        self.fc1 = layers.Dense(1024)
        # Apply Dropout (if is_training is False, dropout is not applied).
        self.dropout = layers.Dropout(rate=0.5)

        # Output layer, class prediction.
        self.out = layers.Dense(10)

    # Set forward pass.
    def call(self, x, is_training=False):
        x = tf.reshape(x, [-1, 28, 28, 1])
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.dropout(x, training=is_training)
        x = self.out(x)
        if not is_training:
            # tf cross entropy expect logits without softmax, so only
            # apply softmax when not training.
            x = tf.nn.softmax(x)
        return x

model = MyModel()

loss_object = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.Adam()

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

2、保存模型参数

2.1 不限制 checkpoint 文件个数

EPOCHS = 5

checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    path = checkpoint.save('./save/model.ckpt')
    print("model saved to %s" % path)

2.2 限制 checkpoint 文件个数

EPOCHS = 5

checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', max_to_keep=3)
for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    path = manager.save(checkpoint_number=epoch)
    print("model saved to %s" % path)

3、加载模型参数

model_to_be_restored = MyModel()
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)      
checkpoint.restore(tf.train.latest_checkpoint('./save')) 
for test_images, test_labels in test_ds:
    y_pred = np.argmax(model_to_be_restored.predict(test_images), axis=-1)
    print("test accuracy: %f" % (sum(tf.cast(y_pred == test_labels, tf.float32)) / x_test.shape[0]))
test accuracy: 0.989600

参考资料

简单粗暴 TensorFlow 2

评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

cofisher

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值