ModelCheckpoint技术

在神经网络的训练学习过程中,常常需要把训练好的模型保存下来,ModelCheckpoint技术就是一种很实用的模型保存与改进方法。

在keras中通过回调API实现Checkpoint功能,本质上是callbacks的一个类。使用前需要从keras库中调用:

from kearas.callbacks import ModelCheckpoint

ModelCheckpoint的一般格式是:

checkpoint = ModelCheckpoint(filename, monitor='loss', verbose=1, save_best_only=True, mode='min', save_weights_only=False, period=5)

filename是保存的文件名(含路径)。
monitor是需要监测的值。
verbose是信息展示模式。
save_best_only, True or False. True, 保存训练集上性能最好的模型。
mode是模型评判准则。
save_weights_only, True or False. True, 只保存模型权重,否则保存整个模型。
period是checkpoint之间间隔的epoch数。

ModelCheckpoint主要有两个功能:
1、利用checkpoint改进模型。

filename = 'improvement-{epoch:02d}-{loss:.2f}.hdf5'
checkpoint = ModelCheckpoint(filename, monitor='loss', verbose=1, save_best_only=True, mode='min', period=5)
model.fit(x_train, y_train, epochs=500, batch_size=100, callbacks=[checkpoint], verbose=1, shuffle=False)

程序执行后,会保存一系列文件,我们可以通过这些文件了解模型的训练过程。
2、利用checkpoint获得最佳模型。

filename = 'improvement-best.hdf5'
checkpoint = ModelCheckpoint(filename, monitor='loss', verbose=1, save_best_only=True, mode='min', period=5)
model.fit(x_train, y_train, epochs=500, batch_size=100, callbacks=[checkpoint], verbose=1, shuffle=False)

程序执行后,文件名相同的文件依次覆盖,最后只有一个文件就是训练效果最好的模型。

那么如何加载一个已经训练好的模型呢?也有两种方法:
1、整模型加载。
如果前面不是只保存了权重的话,在这里是可以加载整个模型的,加载的方法也很简单:

from keras.models import load_model
load_model(filename/path)

2、仅加载权重。
我们也可以只加载权重。

model.load_weights(filename/path)

这样通过ModelCheckpoint技术,我们就可以实现模型的保存与改进。

  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值