一、load_state_dict(strict)中参数 strict的使用
load_state_dict(strict)
中的参数strict
默认是True
,这时候就需要严格按照模型中参数的Key
值来加载参数,如果增删了模型的结构层,或者改变了原始层中的参数,加载就会报错。
相反地,如果设置stric
t为Flase
,就可以只加载具有相同名称的参数层,对于修改的模型结构层进行随机赋值。这里需要注意的是,如果只是改变了原来层的参数,但是没有换名称,依然还是会报错。因为根据key值找到对应的层之后,进行赋值,发现参数不匹配。这时候可以将原来的层换个名称,再加载就不会报错了。最后,大家需要注意的是,strict=Flase要谨慎使用,因为很有可能你会一点参数也没加载进来,具体原因请看下文。
二、使用多GPU训练后的模型加载问题
多GPU训练模型的好处不必多说,毕竟“钞能力”的力量不可小觑。但是,我们需要注意的是,如何加载多GPU训练的模型参数。在执行完函数model = nn.DataParallel(model, device_ids=[0,1,2,3])这条语句后,会给网络中所有的结构层的名称添加module这个字符,此时,如果我们直接使用 model.load_state_dict(torch.load(“model.pth”),strict=True)将会报错,如果你灵机一动将strict的参数改为False,程序是不会报错了,但是测试结果会低到离谱,因为压根就没有参数加载进来,每一层的名称前都添加了module,所以名称都是不匹配的。
这时候有两种解决问题的方法,一是在加载模型前,依旧使用model = nn.DataParallel (model, device_ids=[0,1,2,3])给模型每一层名称前添加module的字符。不过当我们想要单卡去测试模型时就遇到问题了,此时我们需要手动删除掉模型名称中的"module."这7个字符,注意是7个,还有个 . 这样做可以自由地更改模型参数的名称,不仅可以删减前缀"module. ",同时也能增加前缀,这个在模型拼接时会比较方便。
import torch
import torch.nn as nn
import Model.pvt_v2 as PvT
from collections import OrderedDict
net=PvT.pvt_v2_b4()#
state_dict = torch.load("/datasets/Dset_Jerry/Checkpoint/CC-CXRI-P/PvT_B32_S384/PvT_18.pkl") # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain module.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove module.
new_state_dict[name] = v # 新字典的key值对应的value为一一对应的值。
# load params
net.load_state_dict(new_state_dict, strict=True) # 重新加载这个模型。