关于Pytorch加载模型参数的避坑指南

一、load_state_dict(strict)中参数 strict的使用

load_state_dict(strict)中的参数strict默认是True,这时候就需要严格按照模型中参数的Key值来加载参数,如果增删了模型的结构层,或者改变了原始层中的参数,加载就会报错。

         相反地,如果设置strict为Flase,就可以只加载具有相同名称的参数层,对于修改的模型结构层进行随机赋值。这里需要注意的是,如果只是改变了原来层的参数,但是没有换名称,依然还是会报错。因为根据key值找到对应的层之后,进行赋值,发现参数不匹配。这时候可以将原来的层换个名称,再加载就不会报错了。最后,大家需要注意的是,strict=Flase要谨慎使用,因为很有可能你会一点参数也没加载进来,具体原因请看下文。

二、使用多GPU训练后的模型加载问题

        多GPU训练模型的好处不必多说,毕竟“钞能力”的力量不可小觑。但是,我们需要注意的是,如何加载多GPU训练的模型参数。在执行完函数model = nn.DataParallel(model, device_ids=[0,1,2,3])这条语句后,会给网络中所有的结构层的名称添加module这个字符,此时,如果我们直接使用 model.load_state_dict(torch.load("model.pth"),strict=True)将会报错,如果你灵机一动将strict的参数改为False,程序是不会报错了,但是测试结果会低到离谱,因为压根就没有参数加载进来,每一层的名称前都添加了module,所以名称都是不匹配的。

         这时候有两种解决问题的方法,一是在加载模型前,依旧使用model = nn.DataParallel (model, device_ids=[0,1,2,3])给模型每一层名称前添加module的字符。不过当我们想要单卡去测试模型时就遇到问题了,此时我们需要手动删除掉模型名称中的"module."这7个字符,注意是7个,还有个 .    这样做可以自由地更改模型参数的名称,不仅可以删减前缀"module. ",同时也能增加前缀,这个在模型拼接时会比较方便。

import torch
import torch.nn as nn
import Model.pvt_v2 as PvT
from collections import OrderedDict

net=PvT.pvt_v2_b4()# 
state_dict = torch.load("/datasets/Dset_Jerry/Checkpoint/CC-CXRI-P/PvT_B32_S384/PvT_18.pkl")  # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain module.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove module.
    new_state_dict[name] = v  # 新字典的key值对应的value为一一对应的值。
# load params
net.load_state_dict(new_state_dict, strict=True)  # 重新加载这个模型。

  • 29
    点赞
  • 65
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
### 回答1: 函数。 PyTorch中的load_state_dict()函数可以把一个保存的模型参数加载模型中。它可以把模型参数从一个字典对象中加载模型中,也可以把参数从一个保存的文件中加载模型中。这个函数可以帮助用户从模型的训练中恢复,从而可以继续训练模型或者使用模型进行预测。 ### 回答2: pytorchload_state_dict()函数是用于加载预训练模型参数的方法。它接收一个state_dict参数state_dict是一个Python字典对象,包含了模型参数和对应的权重。load_state_dict()方法将这些参数加载到指定的模型中。 PyTorch模型通过一个有序字典对象来存储所有的模型参数,包括模型的可训练和不可训练参数state_dict中的key对应于模型的每个参数的名称,value则存储了相应参数的权重。 当我们进行模型训练,并在某个epoch达到最佳模型时,我们可以使用state_dict将这些训练好的参数保存下来。之后,当我们需要再次使用这个模型时,可以使用load_state_dict()加载这些保存下来的参数,以便于我们继续模型训练或进行预测。 使用load_state_dict()的基本步骤是: 1. 定义一个与要加载参数模型相同的模型对象; 2. 调用load_state_dict()方法并传入保存下来的state_dict; 3. 通过调用加载参数模型对象就可以使用保存下来的参数进行训练或者进行预测。 在加载参数时,需要注意参数的名称和结构应该与之前保存的参数保持一致,否则会导致参数加载失败。可以通过指定strict=False来允许一部分参数不存在。 除了加载整个模型参数,还可以通过使用load_state_dict()来加载模型的部分参数,例如只加载某个层的参数。 总而言之,pytorchload_state_dict()函数是一个用于加载预训练模型参数的重要工具,可以帮助我们在训练和预测中有效地管理和使用模型的权重。 ### 回答3: pytorch 中的 load_state_dict() 是一个模型加载函数,用于加载预训练模型参数。它是一个模型类的方法,可以通过调用模型对象的 load_state_dict() 函数来使用。 load_state_dict() 函数需要传入一个参数,即预训练模型的状态字典(state_dict)。状态字典是一个 Python 字典,它将每个模型参数名称映射到其对应的参数值。模型的状态字典可以通过模型对象的 state_dict() 函数来获取。 当我们在 PyTorch 中训练一个模型时,优化器会保存模型参数状态以及优化器的状态,方便下次恢复训练。load_state_dict() 函数将会加载预训练模型参数状态字典到当前模型中,以便我们可以从预训练模型中复制参数值。 在调用 load_state_dict() 函数之前,我们需要确保预训练模型和当前模型的网络结构是一致的,即它们具有相同的模型参数名字和参数形状。如果预训练模型和当前模型的网络结构不一致,load_state_dict() 函数会抛出错误。 一般来说,加载预训练模型的过程分为两个步骤。首先,我们创建一个空的模型对象,并根据预训练模型的网络结构进行初始化。然后,我们调用 load_state_dict() 函数来加载预训练模型参数。 总之,pytorch 中的 load_state_dict() 函数是一个方便的模型加载函数,它可以加载预训练模型参数。我们可以使用它来快速加载训练好的模型,以便进行推理、迁移学习等任务。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值