深度学习模型权重参数的保存和加载(包括选择最优权重)---基于tensorflow


  我们创建好模型之后需要保存模型,以方便后续对模型的读取与调用,保存模型我们可能有下面三种需求:1、 只保存模型权重参数;2、 同时保存模型图结构与权重参数;3、 在训练过程的检查点保存模型数据。下面分别对这三种需求进行实现

1. 只保存模型权重参数

  • Model.save_weights(file_path)  # 将文件保存到save_path
  • Model.load_weights(file_path)   # 将文件读取到save_path

应用场景如下:

注意保存路径的设置是./weight/ 后面的/千万不要少,要不然可能会出现保存路径的问题

在这里插入图片描述
在前面模型已经compile完毕,在训练结束后,保存一下模型的参数。保存的文件夹是save_path的文件夹,要提前设置好,并且是空的,保存完毕后在save_path文件夹下会生成以下三个文件

在这里插入图片描述

file_path可以设置为主文件mnist.py文件夹下的空文件夹,比如上图中的weight用于保存模型参数。
在使用的时候,将模型结构设置为相同后,直接加载参数:Model.load_weights(file_path) 即可使用,但是注意模型结构要相同且已经经过compile才可以,要不然会报错。等于只是省略了model.fit的过程.

2. 保存模型结构和权重参数

  • 保存模型
    model.save(‘net_model.h5’)

  • 模型加载
    new_model=tf.keras.models.load_model(‘net_model.h5’)

Keras使用HDF5标准提供基本保存格式,出于我们的目的,可以将保存的模型视为单个二进制blob。

保存完整的模型非常有用,使我们可以在TensorFlow.js(HDF5, Saved Model) 中加载它们,然后在Web浏览器中训练和运行它们,或者使用TensorFlow Lite(HDF5, Saved Model)将它们转换为在移动设备上运行。
  
所以,我们保存整个模型的时候,保存文件的后缀一般都是.h5

应用场景如下:

# 训练模型
    save_path = 'net_model.h5'
    if os.path.exists(save_path) == False:
        # 优化器
        adam_optimizer = tf.keras.optimizers.Adam(learning_rate, )
        # 编译模型
        model.compile(optimizer=adam_optimizer,
                      loss=tf.keras.losses.sparse_categorical_crossentropy,
                      metrics=['acc'])
        # 模型开始训练
        history = model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
                            validation_data=(x_test, y_test))
        model.save('net_model.h5')
    else:
        model = tf.keras.models.load_model('net_model.h5')

 可以看到在与主文件相同的文件夹下产生了一个以.h5结尾的文件net_model.h5,这个就是保存的完整的模型及参数。一旦这个文件存在的话,那么就可以直接加载这个模型及其参数,不需要训练。如下第二个图,可以加载到已经存下来的网络结构,和上轮训练的最后一轮的参数
在这里插入图片描述
在这里插入图片描述
在这里补充一下,判断一个文件是否存在的方法 以及 判断一个文件夹是否为空的方法:

  • 判断一个文件是否存在的方法os.path.exists(test_file.txt)=False说明该文件夹不存在

if os.path.exists(test_file.txt) == False # 该文件不存在的情况下
print(‘目标文件不存在’)

  • 判断一个文件夹是否为空:len(os.listdir(tar_dir)==0说明改文件夹下为空

if len(os.listdir(tar_dir)) == 0: # 目标文件夹内容为空的情况下
print(“目标文件夹为空”)

3. 在训练过程中检查点checkpoint保存模型权重参数或者结构

tf.keras.callbacks.ModelCheckpoint(参数如下)

  • filepath:string,保存模型文件的路径。
  • monitor:监控:要监控的数量。
  • verbose详细:详细模式,0或1。
  • save_best_only:如果save_best_only = True,则不会覆盖根据监控数量的最新最佳模型。
  • save_weights_only:如果为True,则只有模型的权重保存(model.save_weights(filepath)),否则保存完整模型(model.save(filepath))。
  • mode:{auto,min,max}之一。 如果save_best_only =True,则根据监控数量的最大化或最小化来决定覆盖当前保存文件。
  • 对于val_acc,这应该是max,对于val_loss,这应该是min等。在自动模式下,从监控量的名称自动推断方向。
  • period:检查点之间的间隔(时期数)。

 有时候,我们需要保存训练过程中最好的结果,或者想先暂停训练后续再继续训练,这就需要用到checkpoint保存模型了。
save_best_only=True 保存最好的参数,默认False保存最后一个epoch的参数
save_weights_only=True 只保存参数,默认False 保存整个模型

	save_path = './checkpoint'
    if len(os.listdir(save_path)) == 0:
        cp_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=save_path+'/train.ckpt',
            verbose=1,
            save_best_only=True,
            save_weights_only=True,
            period=1
        )
        history = model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
                            validation_data=(x_test, y_test),callbacks=[cp_callback])
        plt.plot(history.epoch, history.history.get('acc'), label='acc')
        plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
        plt.legend()
        plt.show()
    else:
        model.load_weights(save_path+'/train.ckpt')

 在没有文件夹或者文件夹为空(尚未保存模型的时候)开始训练,训练过程如下,会自动保存更新损失值较低的,即模型效果更好的参数如下图,如果loss没下降,是不会更新的。
在这里插入图片描述
 保存完文件夹下内容如下:对比可发现在save_weights_only=Trued的情况下保存在文件夹里的 三个文件类型 与第一种方式相同。
在这里插入图片描述
如果不设置save_weights_only=True,那么保存的是一整个模型,文件格式如下图第一个所示,用法跟第二个保存完整模型结构类似。加载的时候,直接加载模型即可。加载模型路径是save_path+'train.ckpt'
在这里插入图片描述

save_path = './checkpoint/'
    if len(os.listdir(save_path)) == 0:
        cp_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=save_path+'/train.ckpt',
            verbose=1,
            save_best_only=True,
            period=1
        )
        history = model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
                            validation_data=(x_test, y_test),callbacks=[cp_callback])
        plt.plot(history.epoch, history.history.get('acc'), label='acc')
        plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
        plt.legend()
        plt.show()
    else:
        model = tf.keras.models.load_model(save_path+'train.ckpt')

4. 参考文献

tensorflow模型保存、读取与可训练参数提取
python中,判断文件是否存在的几种方法

  • 13
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

herry_drj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值