【深度学习实战(51)】load_state_dict(strict)

一、load_state_dict(strict)中参数 strict的使用

load_state_dict(strict)中的参数strict默认是True,这时候就需要严格按照模型中参数的Key值来加载参数,如果增删了模型的结构层,或者改变了原始层中的参数,加载就会报错

相反地,如果设置strict为Flase,就可以只加载具有相同名称的参数层,对于修改的模型结构层进行随机赋值。这里需要注意的是,如果只是改变了原来层的参数,但是没有换名称,依然还是会报错。因为根据key值找到对应的层之后,进行赋值,发现参数不匹配。这时候可以将原来的层换个名称,再加载就不会报错了。最后,大家需要注意的是,strict=Flase要谨慎使用,因为很有可能你会一点参数也没加载进来,具体原因请看下文。

二、使用多GPU训练后的模型加载问题

多GPU训练模型的好处不必多说,毕竟“钞能力”的力量不可小觑。但是,我们需要注意的是,如何加载多GPU训练的模型参数。在执行完函数model = nn.DataParallel(model, device_ids=[0,1,2,3])这条语句后,会给网络中所有的结构层的名称添加module这个字符,此时,如果我们直接使用 model.load_state_dict(torch.load(“model.pth”),strict=True)将会报错,如果你灵机一动将strict的参数改为False,程序是不会报错了,但是测试结果会低到离谱,因为压根就没有参数加载进来,每一层的名称前都添加了module,所以名称都是不匹配的。

这时候有两种解决问题的方法,一是在加载模型前,依旧使用model = nn.DataParallel (model, device_ids=[0,1,2,3])给模型每一层名称前添加module的字符。不过当我们想要单卡去测试模型时就遇到问题了,此时我们需要手动删除掉模型名称中的"module."这7个字符,注意是7个,还有个 . 这样做可以自由地更改模型参数的名称,不仅可以删减前缀"module. ",同时也能增加前缀,这个在模型拼接时会比较方便。

import torch
import torch.nn as nn
import Model.pvt_v2 as PvT
from collections import OrderedDict
 
net=PvT.pvt_v2_b4()# 
state_dict = torch.load("/datasets/Dset_Jerry/Checkpoint/CC-CXRI-P/PvT_B32_S384/PvT_18.pkl")  # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain module.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove module.
    new_state_dict[name] = v  # 新字典的key值对应的value为一一对应的值。
# load params
net.load_state_dict(new_state_dict, strict=True)  # 重新加载这个模型。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BILLY BILLY

你的奖励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值