巧用Keras 回调函数CallBacks

keras 回调函数官方文档:https://keras-cn.readthedocs.io/en/latest/other/callbacks/
主要就是包括:

  • ModelCheckpoint:模型检测,断点恢复训练
  • EasyStopping:提前终止
  • TensorBoard:训练可视化
ModelCheckPoint
先看一下参数吧
keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    period=1
)

1. filename:字符串,保存模型的路径
2. monitor:需要监视的值,val_acc或这val_loss
3. verbose:信息展示模式,0为不打印输出信息,1打印
4. save_best_only:当设置为True时,将只保存在验证集上性能最好的模型
5. mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
6. save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
7. period:CheckPoint之间的间隔的epoch数
假如我们有这样一段代码
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

def create_model():
    model = Sequential([
        Dense(512, activation='relu', input_shape=(784,)),
        Dropout(0.2),
        Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# Create a basic model instance
model = create_model()
model.summary()

test_checkpoint_path = "./weights/mnist.h5"
# Create checkpoint callback
checkpoint = ModelCheckpoint(test_checkpoint_path,
                             monitor='val_acc',
                             save_best_only=True,
                             save_weights_only=True,
                             verbose=1)
if os.path.exists(test_checkpoint_path):
     model.load_weights(test_checkpoint_path)
     print("\n Checkpoint loaded..............")

model.fit(train_images, train_labels, batch_size=32, epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[checkpoint])

当我们第一次跑这个程序的时候,会出现这样的提示,这就告诉我们开始讲最好的参数存入了.h5文件。
在这里插入图片描述
如果训练到第4个epoch电脑出现问题,训练暂停了。那么我们可以重新跑这个程序,会发现前面有一段答应输出,这就表示我们的model加载了上一次训练好的weights。

Checkpoint loaded..............

然后从暂停的位置重新开始训练。
这里我要对这个暂停的位置做一个特别说明:开始我天真的以为,当再次运行代码的时候,程序呢会从epoch=5再开始训练,d但是并不是这样子的,代码还是从epoch=1开始训练,但是会发现它开始输出的[loss, acc]和上次断点的位置相同。也就是说这个暂停是对[loss,acc]的暂停,所以我们第二次代码仍然需要跑10 epoch。
那么就有一个问题了这样时间还不是没有节约,断点训练即使再第二的训练中epoch=5就已经达到了最好的acc了,(假设10个epoch达到最佳精度)但是后面5个epoch还得接着跑啊!
所以这个时候我们就可以使用我们的EarlyStopping了啊!

EarlyStopping

参数:
monitor:需要监视的量

patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。

verbose:信息展示模式

mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。

上个代码只需作如下修改:

......
early_stopping = EarlyStopping(monitor='acc')
......
model.fit(train_images, train_labels, batch_size=32, epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[checkpoint, early_stopping])

TensorBoard

这个只知道对训练数据的可视化操作,了解不是很详细:
代码需改:

......
tensor_board = TensorBoard(log_dir='./log')
......
model.fit(train_images, train_labels, batch_size=32, epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[checkpoint, early_stopping, tensor_bpard])

然后打开终端:
输入

tensorboard --logdir=/full_path_to_your_logs

在浏览器输入:http://localhost:6006/
就可以看到:
在这里插入图片描述

参考链接:
https://machinelearningmastery.com/check-point-deep-learning-models-keras/
https://www.codetd.com/article/2343030
https://zhuanlan.zhihu.com/p/44854276
https://www.jianshu.com/p/321eb9d195cc

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值