深度学习模型权重参数的保存和加载(包括选择最优权重)---基于tensorflow

文章介绍了在Tensorflow中如何保存和加载模型的三种方式:1)仅保存权重参数;2)保存模型结构和权重参数,通常保存为.h5文件;3)在训练过程中使用checkpoint保存最佳模型。此外,还提到了ModelCheckpoint回调函数在训练过程中的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


  我们创建好模型之后需要保存模型,以方便后续对模型的读取与调用,保存模型我们可能有下面三种需求:1、 只保存模型权重参数;2、 同时保存模型图结构与权重参数;3、 在训练过程的检查点保存模型数据。下面分别对这三种需求进行实现

1. 只保存模型权重参数

  • Model.save_weights(file_path)  # 将文件保存到save_path
  • Model.load_weights(file_path)   # 将文件读取到save_path

应用场景如下:

注意保存路径的设置是./weight/ 后面的/千万不要少,要不然可能会出现保存路径的问题

在这里插入图片描述
在前面模型已经compile完毕,在训练结束后,保存一下模型的参数。保存的文件夹是save_path的文件夹,要提前设置好,并且是空的,保存完毕后在save_path文件夹下会生成以下三个文件

在这里插入图片描述

file_path可以设置为主文件mnist.py文件夹下的空文件夹,比如上图中的weight用于保存模型参数。
在使用的时候,将模型结构设置为相同后,直接加载参数:Model.load_weights(file_path) 即可使用,但是注意模型结构要相同且已经经过compile才可以,要不然会报错。等于只是省略了model.fit的过程.

2. 保存模型结构和权重参数

  • 保存模型
    model.save(‘net_model.h5’)

  • 模型加载
    new_model=tf.keras.models.load_model(‘net_model.h5’)

Keras使用HDF5标准提供基本保存格式,出于我们的目的,可以将保存的模型视为单个二进制blob。

保存完整的模型非常有用,使我们可以在TensorFlow.js(HDF5, Saved Model) 中加载它们,然后在Web浏览器中训练和运行它们,或者使用TensorFlow Lite(HDF5, Saved Model)将它们转换为在移动设备上运行。
  
所以,我们保存整个模型的时候,保存文件的后缀一般都是.h5

应用场景如下:

# 训练模型
    save_path = 'net_model.h5'
    if os.path.exists(save_path) == False:
        # 优化器
        adam_optimizer = tf.keras.optimizers.Adam(learning_rate, )
        # 编译模型
        model.compile(optimizer=adam_optimizer,
                      loss=tf.keras.losses.sparse_categorical_crossentropy,
                      metrics=['acc'])
        # 模型开始训练
        history = model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
                            validation_data=(x_test, y_test))
        model.save('net_model.h5')
    else:
        model = tf.keras.models.load_model('net_model.h5')

 可以看到在与主文件相同的文件夹下产生了一个以.h5结尾的文件net_model.h5,这个就是保存的完整的模型及参数。一旦这个文件存在的话,那么就可以直接加载这个模型及其参数,不需要训练。如下第二个图,可以加载到已经存下来的网络结构,和上轮训练的最后一轮的参数
在这里插入图片描述
在这里插入图片描述
在这里补充一下,判断一个文件是否存在的方法 以及 判断一个文件夹是否为空的方法:

  • 判断一个文件是否存在的方法os.path.exists(test_file.txt)=False说明该文件夹不存在

if os.path.exists(test_file.txt) == False # 该文件不存在的情况下
print(‘目标文件不存在’)

  • 判断一个文件夹是否为空:len(os.listdir(tar_dir)==0说明改文件夹下为空

if len(os.listdir(tar_dir)) == 0: # 目标文件夹内容为空的情况下
print(“目标文件夹为空”)

3. 在训练过程中检查点checkpoint保存模型权重参数或者结构

tf.keras.callbacks.ModelCheckpoint(参数如下)

  • filepath:string,保存模型文件的路径。
  • monitor:监控:要监控的数量。
  • verbose详细:详细模式,0或1。
  • save_best_only:如果save_best_only = True,则不会覆盖根据监控数量的最新最佳模型。
  • save_weights_only:如果为True,则只有模型的权重保存(model.save_weights(filepath)),否则保存完整模型(model.save(filepath))。
  • mode:{auto,min,max}之一。 如果save_best_only =True,则根据监控数量的最大化或最小化来决定覆盖当前保存文件。
  • 对于val_acc,这应该是max,对于val_loss,这应该是min等。在自动模式下,从监控量的名称自动推断方向。
  • period:检查点之间的间隔(时期数)。

 有时候,我们需要保存训练过程中最好的结果,或者想先暂停训练后续再继续训练,这就需要用到checkpoint保存模型了。
save_best_only=True 保存最好的参数,默认False保存最后一个epoch的参数
save_weights_only=True 只保存参数,默认False 保存整个模型

	save_path = './checkpoint'
    if len(os.listdir(save_path)) == 0:
        cp_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=save_path+'/train.ckpt',
            verbose=1,
            save_best_only=True,
            save_weights_only=True,
            period=1
        )
        history = model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
                            validation_data=(x_test, y_test),callbacks=[cp_callback])
        plt.plot(history.epoch, history.history.get('acc'), label='acc')
        plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
        plt.legend()
        plt.show()
    else:
        model.load_weights(save_path+'/train.ckpt')

 在没有文件夹或者文件夹为空(尚未保存模型的时候)开始训练,训练过程如下,会自动保存更新损失值较低的,即模型效果更好的参数如下图,如果loss没下降,是不会更新的。
在这里插入图片描述
 保存完文件夹下内容如下:对比可发现在save_weights_only=Trued的情况下保存在文件夹里的 三个文件类型 与第一种方式相同。
在这里插入图片描述
如果不设置save_weights_only=True,那么保存的是一整个模型,文件格式如下图第一个所示,用法跟第二个保存完整模型结构类似。加载的时候,直接加载模型即可。加载模型路径是save_path+'train.ckpt'
在这里插入图片描述

save_path = './checkpoint/'
    if len(os.listdir(save_path)) == 0:
        cp_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=save_path+'/train.ckpt',
            verbose=1,
            save_best_only=True,
            period=1
        )
        history = model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
                            validation_data=(x_test, y_test),callbacks=[cp_callback])
        plt.plot(history.epoch, history.history.get('acc'), label='acc')
        plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
        plt.legend()
        plt.show()
    else:
        model = tf.keras.models.load_model(save_path+'train.ckpt')

4. 参考文献

tensorflow模型保存、读取与可训练参数提取
python中,判断文件是否存在的几种方法

<think>嗯,用户想了解Excel中的VLOOKUP函数的用法参数含义。首先,我需要回忆一下VLOOKUP的基本结构。VLOOKUP是垂直查找函数,用于在表格的首列查找指定的值,然后返回该行中指定列的数据。它的四个参数分别是查找值、表格范围、列索引号匹配模式。 接下来,我得确认每个参数的具体作用。第一个参数是lookup_value,也就是用户要查找的值,比如产品ID或者姓名。第二个参数是table_array,即查找的范围,这里要注意的是,查找值必须位于这个范围的第一列,否则函数无法正确工作。第三个参数是col_index_num,也就是用户想要返回的数据在表格范围中的第几列,这个列数是从表格范围的第一列开始算起的,是整个工作表的列。第四个参数是range_lookup,决定是精确匹配还是近似匹配,通常用FALSE或0来指定精确匹配,避免错误。 然后,用户可能需要一个具体的例子来理解如何使用。比如,假设有一个产品价格表,用户想根据产品ID查找对应的价格。这时候,VLOOKUP就可以在ID列查找,然后返回价格列的数据。同时,要提醒用户注意表格范围最好使用绝对引用,比如$A$2:$C$100,这样在拖动公式时范围会变化。 可能遇到的问题也需要考虑。比如,如果查找值在表格的第一列,或者列索引号超过了表格的列数,函数会返回错误。此外,数据格式一致也可能导致问题,比如查找值是数字,而表格中是文本格式的数字,这时候需要统一格式。还有,当出现#N/A错误时,可能是查找值存在,或者匹配模式设置正确,需要检查参数是否正确。 用户可能还会问到如何避免常见的错误,比如部分数据匹配出错的情况。这时候需要检查是否有多余的空格,或者数据格式一致,或者表格范围没有锁定导致拖动时偏移。此外,XLOOKUP作为VLOOKUP的升级版,可能在某些情况下更方便,比如可以向左查找或者返回多列数据,这也是可以提及的内容。 最后,整理步骤,从基础语法到参数说明,再到具体案例,逐步引导用户理解如何使用VLOOKUP函数。同时,提供常见错误解决方法,帮助用户在实际应用中避免问题。如果有必要,还可以比较VLOOKUP其他查找函数的区别,比如HLOOKUP或XLOOKUP,但根据用户的问题,主要集中在VLOOKUP本身即可。</think>### Excel VLOOKUP函数详解 #### 一、基础语法与参数说明 VLOOKUP函数的语法为: $$=VLOOKUP(lookup\_value, table\_array, col\_index\_num, [range\_lookup])$$ 包含4个参数: 1. **lookup_value**(必填):要查找的值(如单元格引用或具体值) 2. **table_array**(必填):包含数据的表格范围(必须包含查找列返回列) 3. **col_index_num**(必填):返回值所在列的序号(从table_array第一列开始计数) 4. **range_lookup**(可选):匹配类型 - `TRUE`/`1`:近似匹配(默认值,需数据升序排列) - `FALSE`/`0`:精确匹配(常用选项) [^1][^2] #### 二、使用步骤演示(工资表查询案例) 假设需要根据员工编号查询工资: 1. 建立查询单元格(如`B12`) 2. 输入公式: ```excel =VLOOKUP(A12, $A$2:$D$100, 4, 0) ``` - `A12`:待查询的员工编号 - `$A$2:$D$100`:锁定数据区域(绝对引用) - `4`:返回第4列(工资列) - `0`:精确匹配 [^2][^3] #### 三、常见错误与解决方法 | 错误现象 | 原因 | 解决方案 | |---------|------|---------| | #N/A | 查找值存在 | 检查数据源或改用`IFERROR`容错 | | #REF! | 列序号超出范围 | 确认col_index_num ≤ 表格列数 | | 部分匹配失败 | 数据格式一致 | 统一数值/文本格式 | | 结果错位 | 表格未锁定 | 使用`$`符号固定区域引用 | [^3][^4] #### 四、进阶技巧 1. **多条件查询**: 使用辅助列合并多个条件字段 ```excel =VLOOKUP(A2&B2, $D$2:$F$100, 3, 0) ``` 2. **通配符匹配**: `"*"`匹配任意字符,`"?"`匹配单个字符 ```excel =VLOOKUP("张*", $A$2:$C$100, 3, 0) ``` 3. **跨表查询**: 引用其他工作表数据 ```excel =VLOOKUP(A2, Sheet2!$A$2:$D$100, 4, 0) ``` [^1][^4]
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

herry_drj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值