tennsorflow中的断点续训问题

本文介绍了断点续训的概念,即在模型训练过程中,通过保存中间状态以在后续训练中继续。重点讲解了如何使用TensorFlow的ModelCheckpoint回调函数进行模型参数的保存和加载,确保训练的连续性。通过示例展示了如何配置和使用ModelCheckpoint,以及如何在第二次运行时加载先前的训练信息,以实现断点续训,提高模型训练效率。
摘要由CSDN通过智能技术生成

1. 什么是断点续训问题

你可以这样想,当你训练模型需要很多epoch,但是你具体有多少个epoch才能达到你的标准,你把epoch设置高了不仅会加大资源的消耗而且很容易出现其他的一些问题。所以有人就想了,能不能把模型分多次跑,每一次在上一次的基础上继续训练,直到达到我们满意的效果,如果可以这样就存在一个问题,就是如何保存每一次的训练参数,这就是断点续训问题

2. 如何执行断点续训

2.1 tf.keras.callbacks.ModelCheckpoint()类

我们可以使用tensorflow提供的类tf.keras.callbacks.ModelCheckpoint()

tf.keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
    **kwargs,
)
Docstring:     
Save the model after every epoch. # 每一个epoch之后保存参数,可以使用参数save_best_only=True来只保存最好的参数

`filepath` can contain named formatting options,
which will be filled the value of `epoch` and
keys in `logs` (passed in `on_epoch_end`).

For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
then the model checkpoints will be saved with the epoch number and
the validation loss in the filename.

经常用的几个参数:

filepath:保存模型文件的路径

save_best_only:只保存最好模型时的参数(True)

save_weights_only:如果为True,只保存模型参数信息(model.save_weights(filepath)); 如果为False:则保存整个模型(model.save(filepath)

save_freq:'epoch'或者是一个integer.当使用'epoch'时,模型会在每个epoch之后保存代码,当使用integer时,模型会在你指定的几次epoch之后开始保存模型。如果你使用了save_best_only,则这个参数不需要设置

2.2 开始断点续训

checkpoint_save_path = "./checkpoint/mnist.ckpt"  # 声明一个ckpt文件存储路径
if os.path.exists(checkpoint_save_path + '.index'):  # 这里判断是否之前已经存储了模型的训练信息,如果是,则为模型加载之前的参数
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# save model after every epoch
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

第一次运行代码
在这里插入图片描述
第二次运行代码

在这里插入图片描述
观察准确率可以看到,第二次运行代码是在第一次的基础上运行的

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

InceptionZ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值