0 、报错信息:
1、查看model,chpt里面都有啥
weight=torch.load(weight_path,map_location="cpu")
查看weight["state_dict"]中的keys:发现里面的每一层权重都带有'module'
对比模型的state_dict:(model.state_dict().keys())。发现里面并没有‘module’ 字样,因此如果需要导入就必须将权重中的“module”去掉,或者增加Trans 中的“module”
2、此处选择将权重中 “module” 去掉,然后 load
from collections import OrderedDict
new_weight = OrderedDict()
for k,v in weight['state_dict'].items():
k_split = k.split('module.')[-1]
new_weight[k_split] = v
- 按理说这样就能导入成功,但是不行,检查下各自的keys,发现是false
new_weight.keys() == y1.state_dict().keys()
然后啊我就真的去检查俩哪不一样了呀,然后还真的有东西被打印出来了呀。但是咱师兄一模一样的东西为啥他就True,直接把我整自闭了。实验室同门都笑喷了,自己三行代码原模原样照抄都报错…..最后最后原来是版本有问题,torchvision 0.13 ,而我原本环境用的0.12.。以后记住别人能运行,而你不行:
1) 是不是哪里字符有问题
2) 是不是版本问题!!!!
最后附上整体代码:
model=Trans(image_size=256,patch_size=16,num_layers=12,num_heads=8,hidden_dim=768,mlp_dim=768*4,dropout=0.0,representation_size=768*2)、
weight_path = r'./checkpoint.pth.tar'
weight = torch.load(weight_path,map_location="cpu")
#新建一个权重orderdict
from collections import OrderedDict
new_weight = OrderedDict()
for k,v in weight['state_dict'].items():
k_split = k.split('module.')[-1]
new_weight[k_split] = v
model.load_state_dict(new_weight)
#可以检查哪项没有修改,如果new_weight.keys() == y1.state_dict().keys()
for index, k in enumerate(model.state_dict().keys()):
if list(new_weight.keys())[index] != k:
print(k,f"__",list(new_weight.keys())[index])