训练模型代码如下:
net = ResNet18().cuda()
net = nn.DataParallel(net)
……
torch.save(net.state_dict(), 'path')
加载保存的模型参数代码如下:
net = ResNet18().cuda()
net.load_state_dict(torch.load('path'))
net = nn.DataParallel(net)
随后运行便会报出文章标题的错误,报错原因是调用了DataParallel函数,保存的模型参数包含了.module键,所以解决方法有两种,可以先将模型包装到 nn.DataParallel 中,然后加载 state_dict,或者保存net.module.state_dict(),具体操作如下:
第一种:将加载保存的模型参数代码改成:
net = ResNet18().cuda()
net = nn.DataParallel(net)
net.load_state_dict(torch.load('path'))
第二种:将训练模型的代码中的torch.save语句改成:
torch.save(net.module.state_dict(), 'path')