系列文章往期回顾
睿智的keras深度学习(零)——keras使用时可能遇到的参数及含义
睿智的keras深度学习(二)——函数式API构建顺序模型快速上手
使用已有的回调函数
检查点
训练模型时有时需要很长时间,这个过程中如果不加任何处理,一旦中断之前训练到一半的模型将不复存在,所以我们需要建立检查点来在模型训练的过程中对模型进行检查并保存
from keras.callbacks import ModelCheckpoint
checkpoint_cb = ModelCheckpoint("要保存的模型的名字.h5", save_best_only=True)
history = history = model.fit(X_train_A, y_train, epochs=20, validation_data=(X_valid, y_valid), callbacks=[checkpoint_cb])
在训练时设置检查点参数可以有效的保存模型
提前停止
训练模型时有时我们设置的轮次数会过多,在这个时候选用提前停止的回调函数可以令模型提前结束
from keras.callbacks import EarlyStopping
early_stopping_cb = EarlyStopping(patience=10, restore_best_weights=True)
history = history = model.fit(X_train, y_train, epochs=100, validation_data=(X_valid, y_valid), callbacks=[early_stopping_cb])
pd.DataFrame(history.history).plot()
plt.show()
自定义回调
利用以及写好的回调函数可能不足以满足我们的需求,所以我们可以利用Callback类的继承编写自己的自定义函数,这里给出一个示例,大家可以自己尝试一下
from keras.callbacks import Callback
class MyCallback(Callback):
def on_train_begin(self, logs={}):
print('on_train_begin')
def on_epoch_end(self, epoch, logs={}):
print('on_epoch_begin')
def on_train_end(self, logs={}):
print('on_train_end:', logs["val_loss"]/logs["loss"])
history = history = model.fit(X_train, y_train, epochs=20, validation_data=(X_valid, y_valid, callbacks=[MyCallback()])