tensorflow2的checkpoint恢复训练

假如我定义了一个网络进行训练:

import tensorflow as tf
import numpy as np

class MNISTLoader():
    def __init__(self):
        mnist = tf.keras.datasets.mnist
        (self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()
        # MNIST中的图像默认为uint8(0-255的数字)。以下代码将其归一化到0-1之间的浮点数,并在最后增加一维作为颜色通道
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]
        self.train_label = self.train_label.astype(np.int32)    # [60000]
        self.test_label = self.test_label.astype(np.int32)      # [10000]
        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]

    def get_batch(self, batch_size):
        # 从数据集中随机取出batch_size个元素并返回
        index = np.random.randint(0, self.num_train_data, batch_size)
        return self.train_data[index, :], self.train_label[index]


class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()    # Flatten层将除第一维(batch_size)以外的维度展平
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)

    def call(self, inputs):         # [batch_size, 28, 28, 1]
        x = self.flatten(inputs)    # [batch_size, 784]
        x = self.dense1(x)          # [batch_size, 100]
        x = self.dense2(x)          # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output


num_epochs = 5
batch_size = 50
learning_rate = 0.001

model = MLP()
data_loader = MNISTLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
checkpoint = tf.train.Checkpoint(myAwesomeModel=model) # 实例化Checkpoint,取名为myAwesomeModel,设置保存对象为model
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
    if batch_index % 100 == 0:                              # 每隔100个Batch保存一次
            path = checkpoint.save('./save/model.ckpt')         # 保存模型参数到文件
            print("model saved to %s" % path)
batch 0: loss 2.399996
model saved to ./save/model.ckpt-1
batch 1: loss 2.191273
batch 2: loss 2.172761
batch 3: loss 2.080019
batch 4: loss 1.949049
batch 5: loss 1.927595
batch 6: loss 1.862676
batch 7: loss 1.937338
batch 8: loss 1.783551
batch 9: loss 1.693900
batch 10: loss 1.608254
batch 11: loss 1.595733
batch 12: loss 1.483589
batch 13: loss 1.745229
batch 14: loss 1.605927
batch 15: loss 1.411374
batch 16: loss 1.414417

......

这个时候每隔100个batch就报存了一次参数。假如说我们的电脑突然遇到故障了,下一次我不想再重头训练怎么办?这个时候就可以导入原先保存的最新的checkpoint再训练:

import tensorflow as tf
import numpy as np


class MNISTLoader():
    def __init__(self):
        mnist = tf.keras.datasets.mnist
        (self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()
        # MNIST中的图像默认为uint8(0-255的数字)。以下代码将其归一化到0-1之间的浮点数,并在最后增加一维作为颜色通道
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]
        self.train_label = self.train_label.astype(np.int32)    # [60000]
        self.test_label = self.test_label.astype(np.int32)      # [10000]
        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]

    def get_batch(self, batch_size):
        # 从数据集中随机取出batch_size个元素并返回
        index = np.random.randint(0, self.num_train_data, batch_size)
        return self.train_data[index, :], self.train_label[index]


class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()    # Flatten层将除第一维(batch_size)以外的维度展平
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)

    def call(self, inputs):         # [batch_size, 28, 28, 1]
        x = self.flatten(inputs)    # [batch_size, 784]
        x = self.dense1(x)          # [batch_size, 100]
        x = self.dense2(x)          # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output


num_epochs = 3
batch_size = 5
learning_rate = 0.001
data_loader = MNISTLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
num_batches = int(data_loader.num_train_data // batch_size * num_epochs)


model = MLP() # 实例化模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model) # myAwesomeModel,这是你原来保存的checkpoint时的model名字
checkpoint.restore(tf.train.latest_checkpoint('./save')) # 恢复最新的checkpoint

for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))


WARNING:tensorflow:Unresolved object in checkpoint: (root).myAwesomeModel.dense1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).myAwesomeModel.dense1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).myAwesomeModel.dense2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).myAwesomeModel.dense2.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
batch 0: loss 0.023930
batch 1: loss 0.014747
batch 2: loss 0.005468
batch 3: loss 0.000008
batch 4: loss 0.000106
batch 5: loss 0.000138
batch 6: loss 0.000322
batch 7: loss 0.000636
batch 8: loss 0.000061

......

可以看到,它不是从头开始训练,loss的初始值就只有0.02了

当你要在测试集上测试的时候,也可以直接恢复之后使用:

model_to_be_restored = MLP()
# 实例化Checkpoint,设置恢复对象为新建立的模型model_to_be_restored
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)      
checkpoint.restore(tf.train.latest_checkpoint('./save'))    # 从文件恢复模型参数
y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)
print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))
test accuracy: 0.975400

好了,就是这样,如果对您有帮助就点个赞吧。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: TensorFlow的checkpoint是一种保存模型参数的文件格式,它可以在训练过程中保存模型的参数,以便在需要时恢复模型的状态。checkpoint文件包含了模型的权重、偏置、梯度等参数,可以用于继续训练模型或者在其他设备上部署模型。在TensorFlow中,可以使用tf.train.Saver类来创建和加载checkpoint文件。 ### 回答2: TensorFlow是一个非常流行的机器学习框架,能够帮助数据科学家和开发人员快速开发、部署和管理机器学习模型。TensorFlow的模型保存和恢复机制是其最重要的特点之一,而在这个机制中,checkpoint文件起着至关重要的作用。 TensorFlow checkpoint是一个用于存储模型训练期间所有变量的数据结构,它包含整个 TensorFlow 图形的状态,包括各个张量的形状、数据类型和值等。简单来说,checkpoint文件就是一种二进制文件,通过它可以保存模型在训练过程中的中间状态,以便在需要时恢复模型继续训练、进行验证或推理等。 在使用 TensorFlow 训练模型时,checkpoint文件通常包含三个部分:checkpoint文件本身、一个标识最新checkpoint的文本文件和一个或多个用于表示训练步骤的整数值。这些文件通常存储在同一个目录下,并根据训练的进程和步骤进行命名,以便在需要时对它们进行访问和恢复。 TensorFlow Checkpoint提供了一种非常灵活的保存和恢复模型的机制,可以在不同的环境中使用,包括本地和分布式环境。它也可以与其他框架和工具集成,如TensorBoard、TensorFlow Serving和云平台等。此外,TensorFlow Checkpoint还提供了一些其他的高级特性,如变量共享、变量过滤、多项式捕捉等。这些特性可以帮助用户更方便地管理和调试大型模型。 总的来说,TensorFlow Checkpoint是一个非常重要的机制,可以使用户更好地管理、保存和恢复训练中的 Tensorflow 模型。通过使用 Checkpoint,用户可以更灵活、安全地对模型训练和测试的状态进行管理,从而使得模型能够在不同的场景中具有更好的性能和效果。 ### 回答3: TensorFlow checkpoint是一种用于保存模型参数的机制,当长时间训练模型时,我们往往希望能够保存模型参数,以便在必要时进行恢复或在新的任务上继续训练。 TensorFlow checkpoint将模型的所有可训练参数保存在一组二进制文件中,并使用索引文件来跟踪每个参数的最新值。这种机制允许我们将模型保存到磁盘中并稍后恢复它,以便进行推断或继续训练。 TensorFlow checkpoint的使用非常简单,只需使用`tf.train.Saver`类将模型参数保存到文件中。例如,以下代码演示了如何在每个epoch结束时保存模型: ```python saver = tf.train.Saver() with tf.Session() as sess: # 训练模型 # ... # 保存模型 saver.save(sess, "./model.ckpt", global_step=epoch) ``` 在上面的代码中,我们使用`saver.save()`方法将模型参数保存到名为`model.ckpt`的文件中,并将当前epoch数作为全局步数以确保每个文件的唯一性。稍后,我们可以在其他 TensorFlow 程序中加载模型并恢复所有参数: ```python saver = tf.train.Saver() with tf.Session() as sess: # 加载模型 saver.restore(sess, "./model.ckpt-100") # 在模型上进行推断或继续训练 # ... ``` 在恢复模型时,我们使用`saver.restore()`方法将之前保存的checkpoint文件加载到当前的 TensorFlow 会话中。请注意,我们需要指定全局步数以告诉 TensorFlow 我们希望恢复哪个checkpoint文件。 总而言之,TensorFlow checkpoint提供了一种优雅而简单的方式来保存和恢复模型参数。无论是进行模型推断还是继续训练,都会受益于它所提供的便利性和灵活性。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值