[记录]PyTorch加载预训练权重时可能存在的问题

本文主要记录下我在复现CNN经典模型中加载官方预训练权重需要注意的点以及常见的错误。

1、常规加载预训练权重

        本节所涉及的方法必须保证模型中每一层的名字与预训练权重的对应层名字相同

        法1:

        weights=torch.load(opt.weight,map_location=device)
        net.load_state_dict(weights)

        opt.weight是预训练权重的路径;

        法2:

net=model.to(device)
ckpt=torch.load(weight_path, map_location=device)
ckpt={k: v for k, v in ckpt.items() if net.state_dict()[k].numel() == v.numel()}

        (个人感觉)如果没有对模型做任何修改,法一更加简洁;

2、复现代码后加载预训练权重可能存在的问题

        本节的前提是不修改任意一层的结构(即网络整体的结构是不变的)。相信不少小伙伴和我一样在自己复现完代码之后,总会在加载预训练权重报错。出现这种情况无外乎两种原因:搭建的网络的层名字与预训练权重名字不对应;搭建的网络中出现了多余的层(例如,官方实现时卷积层没有偏置,而你实现过程中加了偏置;或者是搭建的网络中出现了冗余即空的列表)。

        针对名字不对应的情况,可以通过以下代码完成权重加载:

        

from collections import OrderedDict
ckpt=torch.load(weight_path, map_location=device)
net=model.to(device)    #加载模型
module_lst=[i for i in net.state_dict()]    #搭建网络的所有层的名字
weights=OrderedDict()    #创建一个有序字典,存放权重
for idx,(k,v) in enumerate(ckpt.items()):
   if net.state_dict()[module_lst[idx]].numel()==v.numel():    
        weights[module_lst[idx]]=v    #如果对应层参数量相同,保存当前层以及对应参数
net.load_state_dict(weights, strict=False)

        如果上述方式仍然报错的话,就需要检查我们搭建的网络的层数是否与预训练权重的相同(我在检查时候通常会保存到两个txt文件中方便对比;如果整体结构没修改的话,问题通常会出现在一些小的点,例如多加了偏置或者出现了冗余的层),下面就我在复现EfficientNet时加载预训练权重出现的问题做简要说明。

        下图是预训练权重对应层以及参数以及我搭建的模型对应的参数:

        通过对比两个字典,发现元素个数不相同=>我们的模型参数数量要多于预训练权重。仔细一看发现我们模型的每个卷积层都多了一行(我也不太清楚为什么会多一行,并且里面没有任何参数):

         通过下面代码将模型中所有的num_batches_tracked删除掉:

ckpt=torch.load(opt.weight,map_location=device)
weights = collections.OrderedDict()
net_layer=net.state_dict()
new_model_dict = {}
for layer_name,val in net_layer.items():
   if "num_batches_tracked" in layer_name:
        pass
   else:
        new_model_dict[layer_name] = val
module_lst = [i for i in new_model_dict]
for idx, (k, v) in enumerate(ckpt.items()):
    if net.state_dict()[module_lst[idx]].numel() == v.numel():
        weights[module_lst[idx]] = v
net.load_state_dict(weights,strict=False)

        经过处理之后,发现参数量与预训练权重的相同。如果处理之后仍然不相同那就要去细心比对了(通常是由于偏置引起的)   

        关于更改网络结构后预训练权重如何加载以及能否使用预训练权重,由于目前未涉及到该方面,后续需要的话会做记录。    

 

### 可能原因分析 ResNet预训练权重文件无法加载通常由以下几个常见因素引起: 1. **网络连接问题**:如果设备无法访问互联网或者存在防火墙限制,则可能导致无法成功下载预训练权重文件[^3]。 2. **URL失效或变更**:官方提供的预训练权重链接可能会因为版本更新或其他原因而发生更改,这将导致程序尝试从错误的地址获取数据。 3. **本地缓存损坏**:有已下载到本地的模型参数可能出现意外中断而导致部分数据丢失或损坏,进而影响后续使用过程中的正常读取操作。 ### 解决方案概述 针对上述提到的各种可能性,可以采取如下措施逐一排查并解决问题: #### 方法一:确认网络环境畅通无阻 确保当前运行环境中具备良好的外部网络连通能力,并且没有任何安全策略阻止对于特定域名称下的资源请求行为。可以通过简单命令测试目标服务器可达状态以及端口开放情况: ```bash ping download.pytorch.org telnet download.pytorch.org 443 ``` 以上两条指令分别用于验证主机存活与否及其HTTPS服务是否可用。假如发现异常现象,则需联系管理员调整设置或是切换至其他允许完全上网的位置再试一次整个流程。 #### 方法二:手动指定新的有效下载源路径 当默认配置里的url字段指向的内容已经不可用,就需要查找最新发布的替代选项并将之替换进去。例如,在PyTorch框架下实现自定义修改model_urls字典内容的方式如下所示: ```python import torchvision.models as models models.model_urls['resnet50'] = 'http://other-mirror.com/path/to/new/resnet50.pth' pretrained_model = models.resnet50(pretrained=True) ``` 这里假设我们找到了一个新的可靠的镜像站点存放有相同结构类型的权值文档,那么只需按照上面示范代码片段那样重新赋值给对应键名即可完成适配工作。 #### 方法三:清除旧版残留记录重试自动安装机制 有候即使一切看起来都准备就绪却依旧报错提示找不到合适的匹配项,这候不妨试着清理掉之前可能存在的干扰因素——即删除`.cache`目录里边存储的相关临文件夹后再执行一遍初始化逻辑看看效果怎样: Linux/MacOS平台: ```bash rm -rf ~/.cache/torch/hub/checkpoints/ ``` Windows平台: ```cmd del /s /q %USERPROFILE%\.cache\torch\hub\checkpoints\ rmdir /s /q %USERPROFILE%\.cache\torch\hub\checkpoints\ ``` 之后再次调用创建实例函数应该就能顺利拉取下来所需的依赖组件了[^2]。 --- ### 总结说明 综上所述,面对ResNet系列架构所附带的预训练参数包难以顺利完成加载这一状况,可以从检查基础通信条件出发逐步深入挖掘潜在隐患所在位置;与此同灵活运用多种手段相结合的办法积极应对直至彻底消除障碍恢复正常运作秩序为止。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值