RuntimeError: Error(s) in loading state_dict for Trans: Missing key(s) in state_dict: “class_token“

0 、报错信息:

 1、查看model,chpt里面都有啥

weight=torch.load(weight_path,map_location="cpu")

 查看weight["state_dict"]中的keys:发现里面的每一层权重都带有'module'

 对比模型的state_dict:(model.state_dict().keys())。发现里面并没有‘module’ 字样,因此如果需要导入就必须将权重中的“module”去掉,或者增加Trans 中的“module

 2、此处选择将权重中 “module” 去掉,然后 load

from collections import OrderedDict
new_weight = OrderedDict()
for k,v in weight['state_dict'].items():
    k_split = k.split('module.')[-1]
    new_weight[k_split] = v
  1. 按理说这样就能导入成功,但是不行,检查下各自的keys,发现是false
new_weight.keys() == y1.state_dict().keys()  

然后啊我就真的去检查俩哪不一样了呀,然后还真的有东西被打印出来了呀。但是咱师兄一模一样的东西为啥他就True,直接把我整自闭了。实验室同门都笑喷了,自己三行代码原模原样照抄都报错…..最后最后原来是版本有问题,torchvision 0.13 ,而我原本环境用的0.12.。以后记住别人能运行,而你不行:

1 是不是哪里字符有问题

2 是不是版本问题!!!!

最后附上整体代码:

model=Trans(image_size=256,patch_size=16,num_layers=12,num_heads=8,hidden_dim=768,mlp_dim=768*4,dropout=0.0,representation_size=768*2)、
weight_path = r'./checkpoint.pth.tar'
weight  = torch.load(weight_path,map_location="cpu")
#新建一个权重orderdict
from collections import OrderedDict
new_weight = OrderedDict()
for k,v in weight['state_dict'].items():
    k_split = k.split('module.')[-1]
    new_weight[k_split] = v
   
model.load_state_dict(new_weight)
#可以检查哪项没有修改,如果new_weight.keys() == y1.state_dict().keys()
for index, k in enumerate(model.state_dict().keys()):
    if  list(new_weight.keys())[index] != k:
    print(k,f"__",list(new_weight.keys())[index])

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值