RuntimeError: Error(s) in loading state_dict for DABNet:
Missing key(s) in state_dict: "init_conv.0.conv.weight", "init_conv.0.bn_prelu.bn.weight", "init_conv.0.bn_prelu.bn.bias","classifier.0.conv.weight".
Unexpected key(s) in state_dict: "epoch", "model". (followed by a long list of similar errors)
Source code:
import torch
from model.DABNet import DABNet
model_dir = "./all_model_dir/dataset/mic/DABNetbs20gpu1_trainval/model_100.pth"
example = torch.rand(1,3,400,400)
net = DABNet(classes=2)
state_dict = torch.load(model_dir, map_location='cpu')  # the whole training checkpoint, not just the weights
net.load_state_dict(state_dict)  # raises the RuntimeError above
net.eval()
traced_script_module = torch.jit.trace(net, example)
traced_script_module.save("traced_DABnet_model.pt")
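Cause of the error:
The training setup and the export setup do not match. The training script saved a checkpoint dict (with "epoch" and "model" entries, as the "Unexpected key(s)" line shows), but the export script passes that whole dict to load_state_dict as if it were the bare model weights. None of the model's parameter names appear at the top level of the checkpoint, so every parameter is reported as missing.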
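The mismatch can be reproduced without torch or the DABNet code. The sketch below is a simplified stand-in for the key comparison that load_state_dict performs in strict mode (the real method also checks tensor shapes and walks every submodule). compare_keys is a hypothetical helper, the key list is the abbreviated one from the error message, and the epoch value and None placeholders stand in for real tensors.

```python
# Simplified model of the key check done by load_state_dict(strict=True).
# Key names come from the error message above; the real lists are longer.


def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) the way strict loading reports them."""
    provided = set(checkpoint_keys)
    expected = set(model_keys)
    missing = [k for k in model_keys if k not in provided]
    unexpected = [k for k in checkpoint_keys if k not in expected]
    return missing, unexpected


# Keys the DABNet instance expects (abbreviated).
model_keys = [
    "init_conv.0.conv.weight",
    "init_conv.0.bn_prelu.bn.weight",
    "init_conv.0.bn_prelu.bn.bias",
    "classifier.0.conv.weight",
]

# Top-level layout of the saved checkpoint: a wrapper dict, not bare weights.
checkpoint = {
    "epoch": 100,  # hypothetical value
    "model": {k: None for k in model_keys},  # None stands in for tensors
}

missing, unexpected = compare_keys(model_keys, checkpoint.keys())
print("Missing:", missing)
print("Unexpected:", unexpected)

# Comparing against the nested dict instead: nothing is missing or unexpected.
print(compare_keys(model_keys, checkpoint["model"].keys()))  # → ([], [])
```

Every model key lands in "missing" and the checkpoint's only two keys land in "unexpected", which is exactly the pattern in the error above.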
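Change
net.load_state_dict(state_dict)
to
net.load_state_dict(state_dict, strict=False)
and the error goes away (note that Python's boolean is False, not FALSE, which would raise a NameError).
Be aware that strict=False only skips the key check. Since none of the checkpoint's top-level keys match the model's parameter names, nothing is actually loaded and the traced model keeps its random initial weights. To export the trained weights, pass the nested dict instead:
net.load_state_dict(state_dict["model"])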
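A slightly more defensive way to load such checkpoints is sketched below. extract_state_dict is a hypothetical helper, not part of PyTorch or the DABNet code. It assumes the weights are either the checkpoint itself or stored under "model", and it also strips the "module." prefix that nn.DataParallel adds to parameter names, in case a checkpoint came from a multi-GPU run. The logic only touches dict keys, so it works on plain dicts as well as real state_dicts.

```python
# Hypothetical helper: pull the bare weights out of a training checkpoint.
# Assumption: the checkpoint is either a bare state_dict or a dict storing it
# under "model"; keys may carry a "module." prefix from nn.DataParallel.


def extract_state_dict(checkpoint, key="model", prefix="module."):
    """Return a state_dict whose keys match the unwrapped model."""
    # Unwrap {"epoch": ..., "model": {...}} style checkpoints.
    if isinstance(checkpoint, dict) and key in checkpoint:
        weights = checkpoint[key]
    else:
        weights = checkpoint
    # Strip the DataParallel prefix only if every key carries it.
    if weights and all(k.startswith(prefix) for k in weights):
        weights = {k[len(prefix):]: v for k, v in weights.items()}
    return weights


# Example with placeholder values instead of tensors.
ckpt = {"epoch": 100, "model": {"module.a.weight": 1, "module.b.bias": 2}}
print(extract_state_dict(ckpt))  # → {'a.weight': 1, 'b.bias': 2}
```

In the export script this would replace the load call with net.load_state_dict(extract_state_dict(state_dict)), keeping the default strict=True so that any remaining mismatch is still reported instead of silently producing an untrained model.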