14.1 使用保存点
长时间运行的程序需要能中途保存,加强健壮性。保存的程序应该可以继续运行,或者直接运行。深度学习的保存点用来存储模型的权重:这样可以继续训练,或者直接开始预测。
Keras有回调API,配合ModelCheckpoint
可以每轮保存网络信息,可以定义文件位置、文件名和保存时机等。例如,损失函数或准确率达到某个标准就保存,文件名的格式可以加入时间和准确率等。ModelCheckpoint
需要传入fit()
函数,也需要安装h5py
库。
14.2 效果变好就保存
好习惯:每轮如果效果变好就保存一下。还是用第7章的模型,用33%的数据测试。
每轮后在测试数据集上验证,如果比之前效果好就保存权重(monitor='val_acc', mode='max')。文件名格式是weights-improvement-val_acc=.2f.hdf5
。
from keras.callbacks import ModelCheckpoint
filepath="savemodel/weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True,mode='max')
callbacks_list = [checkpoint]
14.3 保存最好的模型
也可以只保存最好的模型:每次如果效果变好就覆盖之前的权重文件,把之前的文件名改成固定的就可以:
filepath="weights.best.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True,
mode='max')
callbacks_list = [checkpoint]
# Fit the model
model.fit(X, Y, validation_split=0.33, nb_epoch=150, batch_size=10,
callbacks=callbacks_list, verbose=0)
14.4 导入保存的模型
保存点只保存权重,网络结构需要预先保存。
model.load_weights("weights.best.hdf5")