keras的模型保存与加载


参考链接 keras保存模型中的save()和save_weights()
.

keras的模型保存

保存 模型时可以配合回调函数(callback)的方法,在每一次batch(应该是batch)中,保留最佳的模型参数,常见的模型类型有两种,一种是ckpt,另一种是h5。

h5类型

    bast_model_filepath = './checkpoint/best_bilstm_crf_model.h5'
    checkpoint = keras.callbacks.ModelCheckpoint(
        bast_model_filepath, 
        monitor='val_loss', 
        verbose=1, 
        save_best_only=True,
        mode='min'
        )
    model.fit(
        x=train_X, 
        y=train_y, 
        batch_size=32, 
        epochs=80, 
        validation_data=(dev_X, dev_y), 
        shuffle=True, 
        callbacks=[reduce_lr,earlystop,checkpoint]
        )

以上代码中,代码 bast_model_filepath = ‘./checkpoint/best_bilstm_crf_model.h5’ 确定了模型 保留的形式为.h5格式,根据参考链接所说的,.h5格式的文件既保存了图模型,也保存了权重参数,因此在加载模型的时候,不需要再把图模型建立一遍,直接加载即可。注意一点,如果搭建模型后还未训练模型,就保留了h5文件,此时h5文件中保留的应该只有图模型,没有参数。保存模型时,建议使用h5类型。

ckpt类型

在这里插入图片描述
图片截图自
链接: 北京大学TensorFlow2.0.
保存后的效果图为
在这里插入图片描述
声明一下,这个结果不是上面代码截图的训练结果,代码截图有一个 save_weights_only ,可能只保存了权重参数(.data),没有保存图模型。

通过参阅网上的其他博客,ckpt保存的模型有三个后缀类型。.index仅仅起到索引的作用,用于判断是否有ckpt模型。.meta存储的应该是图模型,.data模型的文件大小最大,存储的是权重参数。ckpt后面的数字是回调函数中,每隔多少次训练保存一次模型。

保存模型和加载模型

在保存模型的时候,用到的代码有 saver.save 的方法,需要先创建一个saver对象,这是很早版本的keras用到的命令,现在的命令一般为上图的 代码截图中的保存方式,也可以采用

model.save('m2.h5')
model.save_weights('m3.h5')

第一种保留的是图模型+权重参数,第二种只保留权重,因此在使用后者保存的模型的时候,需要先把网络结构,即图模型给搭建起来,再使用如下命令

model.load_weights('m3.h5')

采用save_weights,加载模型的时候也必须是load_weights。

要想加载第一种方式保存的模型,采用

model = load_model('m1.h5')

采用此种方式不必创建图模型,直接加载模型即可,极其简便。


ps:回到第一个代码片段,该片段后面其实还要两句代码

    bast_model_filepath = './checkpoint/best_bilstm_crf_model.h5'
    checkpoint = keras.callbacks.ModelCheckpoint(
        bast_model_filepath, 
        monitor='val_loss', 
        verbose=1, 
        save_best_only=True,
        mode='min'
        )
    model.fit(
        x=train_X, 
        y=train_y, 
        batch_size=32, 
        epochs=80, 
        validation_data=(dev_X, dev_y), 
        shuffle=True, 
        callbacks=[reduce_lr,earlystop,checkpoint]
        )
    model.load_weights(bast_model_filepath)
    model.save('./checkpoint/bilstm_crf_model.h5')

通过分析,我认为在回调函数中仅仅保留的是权重参数,这一点可以通过 model.load_weights(bast_model_filepath) 可以看出,而后一个 model.save(’./checkpoint/bilstm_crf_model.h5’) 是把模型 和参数一并保存在h5文件中。
但是分析callback回调函数的源码

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

惊奇地发现其默认为False,那么在回调函数中,有可能保存的就是图结构+权重模型,再往后翻,发现好像确实是如此

                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)

这儿原代码作者在回调函数后又保存了一次模型,其意思可能是加双重保险把。


ps:再做一点说明,model.save_weights 不能配合 model.load_model 使用,但是 model.load_weights 似乎可以和 model.save 连用。

作者水平有限,本博客仅供参考,如有错误,欢迎指正。

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值