pytorch在对以经训练好的网络进行测试时发现以上问题,在装载网络后添加model.eval(),问题得到解决。
model = nn.DataParallel(model).cpu()
model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), False)
model.eval()
此代码段有三个需要注意的地方
nn.DataParallel(model)
如果在训练网络时加入了这个函数,在使用训练好的网络时也需要将此函数加入
model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), False)
若在保存网络时没有保存网络框架只有参数,可以通过以上方法加载网络,若训练网络时使用GPU,使用时在CPU上则需添加参数map_location=torch.device('cpu')
model.eval()
第三个则为题目的问题,在加载完网络后添加以上函数,则可解决题目所遇到的问题