解决方法一:
load_state_dict(torch.load(‘net.pth’)在前,增加
model = nn.DataParallel(model)
就可以了。
例如:
net = NET()
net.cuda()
net = nn.DataParallel(net)
net.load_state_dict(torch.load('net.pth')
如果还不行可以考虑是pytorch版本换成大于1.0.0(小于0.4.0),若果是别人训练好的权重还应该考虑是不是自己的模型和别人的权重文件不匹配