keras的模型保存与加载
参考链接 keras保存模型中的save()和save_weights()
.
keras的模型保存
保存 模型时可以配合回调函数(callback)的方法,在每一次batch(应该是batch)中,保留最佳的模型参数,常见的模型类型有两种,一种是ckpt,另一种是h5。
h5类型
bast_model_filepath = './checkpoint/best_bilstm_crf_model.h5'
checkpoint = keras.callbacks.ModelCheckpoint(
bast_model_filepath,
monitor='val_loss',
verbose=1,
save_best_only=True,
mode='min'
)
model.fit(
x=train_X,
y=train_y,
batch_size=32,
epochs=80,
validation_data=(dev_X, dev_y),
shuffle=True,
callbacks=[reduce_lr,earlystop,checkpoint]
)
以上代码中,代码 bast_model_filepath = ‘./checkpoint/best_bilstm_crf_model.h5’ 确定了模型 保留的形式为.h5格式,根据参考链接所说的,.h5格式的文件既保存了图模型,也保存了权重参数,因此在加载模型的时候,不需要再把图模型建立一遍,直接加载即可。注意一点,如果搭建模型后还未训练模型,就保留了h5文件,此时h5文件中保留的应该只有图模型,没有参数。保存模型时,建议使用h5类型。
ckpt类型
图片截图自
链接: 北京大学TensorFlow2.0.
保存后的效果图为
声明一下,这个结果不是上面代码截图的训练结果,代码截图有一个 save_weights_only ,可能只保存了权重参数(.data),没有保存图模型。
通过参阅网上的其他博客,ckpt保存的模型有三个后缀类型。.index仅仅起到索引的作用,用于判断是否有ckpt模型。.meta存储的应该是图模型,.data模型的文件大小最大,存储的是权重参数。ckpt后面的数字是回调函数中,每隔多少次训练保存一次模型。
保存模型和加载模型
在保存模型的时候,用到的代码有 saver.save 的方法,需要先创建一个saver对象,这是很早版本的keras用到的命令,现在的命令一般为上图的 代码截图中的保存方式,也可以采用
model.save('m2.h5')
model.save_weights('m3.h5')
第一种保留的是图模型+权重参数,第二种只保留权重,因此在使用后者保存的模型的时候,需要先把网络结构,即图模型给搭建起来,再使用如下命令
model.load_weights('m3.h5')
采用save_weights,加载模型的时候也必须是load_weights。
要想加载第一种方式保存的模型,采用
model = load_model('m1.h5')
采用此种方式不必创建图模型,直接加载模型即可,极其简便。
ps:回到第一个代码片段,该片段后面其实还要两句代码
bast_model_filepath = './checkpoint/best_bilstm_crf_model.h5'
checkpoint = keras.callbacks.ModelCheckpoint(
bast_model_filepath,
monitor='val_loss',
verbose=1,
save_best_only=True,
mode='min'
)
model.fit(
x=train_X,
y=train_y,
batch_size=32,
epochs=80,
validation_data=(dev_X, dev_y),
shuffle=True,
callbacks=[reduce_lr,earlystop,checkpoint]
)
model.load_weights(bast_model_filepath)
model.save('./checkpoint/bilstm_crf_model.h5')
通过分析,我认为在回调函数中仅仅保留的是权重参数,这一点可以通过 model.load_weights(bast_model_filepath) 可以看出,而后一个 model.save(’./checkpoint/bilstm_crf_model.h5’) 是把模型 和参数一并保存在h5文件中。
但是分析callback回调函数的源码
def __init__(self, filepath, monitor='val_loss', verbose=0,
save_best_only=False, save_weights_only=False,
mode='auto', period=1):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_save = 0
惊奇地发现其默认为False,那么在回调函数中,有可能保存的就是图结构+权重模型,再往后翻,发现好像确实是如此
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
这儿原代码作者在回调函数后又保存了一次模型,其意思可能是加双重保险把。
ps:再做一点说明,model.save_weights 不能配合 model.load_model 使用,但是 model.load_weights 似乎可以和 model.save 连用。
作者水平有限,本博客仅供参考,如有错误,欢迎指正。