keras 回调函数官方文档:https://keras-cn.readthedocs.io/en/latest/other/callbacks/
主要就是包括:
- ModelCheckpoint:模型检测,断点恢复训练
- EasyStopping:提前终止
- TensorBoard:训练可视化
ModelCheckPoint
先看一下参数吧
keras.callbacks.ModelCheckpoint(
filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
period=1
)
1. filename:字符串,保存模型的路径
2. monitor:需要监视的值,val_acc或这val_loss
3. verbose:信息展示模式,0为不打印输出信息,1打印
4. save_best_only:当设置为True时,将只保存在验证集上性能最好的模型
5. mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
6. save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
7. period:CheckPoint之间的间隔的epoch数
假如我们有这样一段代码
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
def create_model():
model = Sequential([
Dense(512, activation='relu', input_shape=(784,)),
Dropout(0.2),
Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
# Create a basic model instance
model = create_model()
model.summary()
test_checkpoint_path = "./weights/mnist.h5"
# Create checkpoint callback
checkpoint = ModelCheckpoint(test_checkpoint_path,
monitor='val_acc',
save_best_only=True,
save_weights_only=True,
verbose=1)
if os.path.exists(test_checkpoint_path):
model.load_weights(test_checkpoint_path)
print("\n Checkpoint loaded..............")
model.fit(train_images, train_labels, batch_size=32, epochs=10,
validation_data=(test_images, test_labels),
callbacks=[checkpoint])
当我们第一次跑这个程序的时候,会出现这样的提示,这就告诉我们开始讲最好的参数存入了.h5
文件。
如果训练到第4个epoch
电脑出现问题,训练暂停了。那么我们可以重新跑这个程序,会发现前面有一段答应输出,这就表示我们的model加载了上一次训练好的weights。
Checkpoint loaded..............
然后从暂停的位置重新开始训练。
这里我要对这个暂停的位置做一个特别说明:开始我天真的以为,当再次运行代码的时候,程序呢会从epoch=5再开始训练,d但是并不是这样子的,代码还是从epoch=1开始训练,但是会发现它开始输出的[loss, acc]和上次断点的位置相同。也就是说这个暂停是对[loss,acc]的暂停,所以我们第二次代码仍然需要跑10 epoch。
那么就有一个问题了这样时间还不是没有节约,断点训练即使再第二的训练中epoch=5
就已经达到了最好的acc了,(假设10个epoch达到最佳精度)
但是后面5个epoch还得接着跑啊!
所以这个时候我们就可以使用我们的EarlyStopping了啊!
EarlyStopping
参数:
monitor:需要监视的量
patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。
verbose:信息展示模式
mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。
上个代码只需作如下修改:
......
early_stopping = EarlyStopping(monitor='acc')
......
model.fit(train_images, train_labels, batch_size=32, epochs=10,
validation_data=(test_images, test_labels),
callbacks=[checkpoint, early_stopping])
TensorBoard
这个只知道对训练数据的可视化操作,了解不是很详细:
代码需改:
......
tensor_board = TensorBoard(log_dir='./log')
......
model.fit(train_images, train_labels, batch_size=32, epochs=10,
validation_data=(test_images, test_labels),
callbacks=[checkpoint, early_stopping, tensor_bpard])
然后打开终端:
输入
tensorboard --logdir=/full_path_to_your_logs
在浏览器输入:http://localhost:6006/
就可以看到:
参考链接:
https://machinelearningmastery.com/check-point-deep-learning-models-keras/
https://www.codetd.com/article/2343030
https://zhuanlan.zhihu.com/p/44854276
https://www.jianshu.com/p/321eb9d195cc