一般来说有两个原因
1. load错了模型
2. 多卡/单卡 混合train/test
如果是第二个原因,比如多卡训练(即使batchsize=1),然后测试时候模型虽然正确,但是也会报错,只要在网络load前加上就行了
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
# dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
model1 = nn.DataParallel(model1)
model1.to(device)
# 在这之前加上多gpu就好了
model1.load_state_dict(torch.load(certain_scale_model + "/" + "model_{}.pth".format(i)))