错误原因
save和load的代码不匹配.
解决办法
匹配即可,例如
case1
若save代码为
torch.save(network.cpu().state_dict(), model_name)
则load的代码应为
network.load_state_dict(torch.load(model_name))
case2
若save代码为
torch.save(network, model_name)
则load的代码应为
network.load_state_dict(torch.load(model_name).cpu().state_dict())