网络参数加载问题
之前训练网络是在多GPU上用训练的.
torch.nn.DataParallel(net, device_ids)
训练后保存的网络参数
torch.save(net.state_dict(), save_path)
测试时我用的单GPU测试,但是加载网络会报错,没有参数
Missing key(s) in state_dict: “…”
net.load_state_dict(torch.load(model_path)
解决方法
多GPU训练的网络,加载也要用
torch.nn.DataParallel(net, device_ids)