When running the network on the server, the following error occurred:
RuntimeError: Error(s) in loading state_dict for ReflectionNetwork:
Original code:
checkpoint = torch.load("/Pre_train/checkpoint_81_test1.pth.tar",map_location='cuda:0')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
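This error means the parameter names stored in the checkpoint do not match the names the model defines. Before changing the loading code, it can help to see which names differ. Below is a minimal pure-Python sketch; the helper name diff_state_dict_keys is hypothetical, and in practice you would pass it checkpoint['state_dict'].keys() and model.state_dict().keys().

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Compare parameter names in a checkpoint with those a model expects.

    Returns (missing, unexpected):
      missing    - names the model expects but the checkpoint lacks
      unexpected - names in the checkpoint that the model does not define
    """
    ckpt = set(checkpoint_keys)
    model = set(model_keys)
    missing = sorted(model - ckpt)
    unexpected = sorted(ckpt - model)
    return missing, unexpected


if __name__ == "__main__":
    # Toy key lists standing in for checkpoint['state_dict'].keys()
    # and model.state_dict().keys().
    ckpt_keys = ["module.conv1.weight", "module.conv1.bias"]
    model_keys = ["conv1.weight", "conv1.bias"]
    missing, unexpected = diff_state_dict_keys(ckpt_keys, model_keys)
    print("missing:", missing)
    print("unexpected:", unexpected)
```

If the two lists turn out to be the same names with a consistent prefix difference, the mismatch is a naming issue rather than a different architecture.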
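After the change (passing False as the second argument, i.e. strict=False, so that keys missing from or unexpected in the checkpoint are skipped instead of raising an error):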
checkpoint = torch.load("/Pre_train/checkpoint_81_test1.pth.tar",map_location='cuda:0')
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()
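Note that strict=False only skips missing and unexpected keys: any layer whose weights were skipped keeps its freshly initialized values, and tensors with mismatched shapes still raise a size-mismatch error. One common cause of this error (an assumption here, since the post does not list the mismatched keys) is a checkpoint saved from a model wrapped in nn.DataParallel, whose parameter names all start with "module.". In that case, renaming the keys lets you keep the default strict loading. A minimal sketch, with a hypothetical helper strip_prefix:

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that have it.

    Keys without the prefix are kept unchanged; insertion order is preserved.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


if __name__ == "__main__":
    # Toy dict standing in for checkpoint['state_dict']; real values are tensors.
    ckpt = OrderedDict([("module.conv1.weight", "w"), ("module.conv1.bias", "b")])
    print(list(strip_prefix(ckpt).keys()))
```

With this, model.load_state_dict(strip_prefix(checkpoint['state_dict'])) keeps strict=True, so any remaining mismatch is still reported instead of being silently ignored.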
Done, problem solved!