torch.nn.parallel.DistributedDataParallel的模型在进行eval()的时候必须加上with torch.no_grad(),否则就会导致rank==0的卡 卡死在运行eval()后的代码的过程中,而其他卡仍然在进行训练,其他卡不会等这个进行eval()的卡。
在使用中有一个地方很容易出错,代码如下:
# Periodic validation and best-checkpoint save, performed on rank 0 only.
# NOTE(review): per the author's note above, the other ranks do NOT wait for
# rank 0 here — they keep training while rank 0 evaluates; the author reports
# a hang without torch.no_grad() around the eval pass.
# Default '0' so a single-process (non-distributed) run does not crash on
# a missing RANK environment variable.
if int(os.environ.get('RANK', '0')) == 0:
    if epoch % 10 == 0:
        model.eval()  # disable dropout / freeze batchnorm stats for validation
        right_num = 0
        # no_grad: no autograd graph is built during the forward passes
        # (required per the note above; also saves memory).
        with torch.no_grad():
            for idx, (data, label) in enumerate(val_dataloader):
                data = data.to(device)
                label = label.to(device)
                logits = model(data)
                # softmax is monotone, so argmax is unchanged by it; kept to
                # preserve the original computation.
                probs = torch.nn.functional.softmax(logits, dim=1)
                right_num += (torch.argmax(probs, dim=1) == label).sum().cpu().item()
        # Keep the checkpoint with the best validation hit-count so far.
        # right_num0 is assumed to be initialized before the training loop —
        # TODO confirm (it is not defined in this chunk).
        if right_num >= right_num0:
            right_num0 = right_num
            # NOTE(review): if `model` is DDP-wrapped, state_dict() keys carry
            # a 'module.' prefix; model.module.state_dict() would avoid that —
            # confirm which form the loader expects.
            torch.save(model.state_dict(), "./best_dict_resnest50-softmax-64batch-distr.pth")
        print(right_num)
        # BUG FIX: restore training mode, otherwise every epoch after the
        # first validation would train with dropout/batchnorm in eval mode.
        model.train()
torch.cuda.empty_cache()
# Final checkpoint after training completes, saved by rank 0 only.
# Default '0' so a single-process (non-distributed) run does not crash on
# a missing RANK environment variable.
if int(os.environ.get('RANK', '0')) == 0:
    model.eval()  # save in eval mode, matching the author's convention above
    # NOTE(review): torch.save builds no autograd graph, so no_grad() should
    # be unnecessary here; kept because the author reports DDP hangs without
    # it around post-eval code — confirm before removing.
    with torch.no_grad():
        # NOTE(review): for a DDP-wrapped model, keys carry a 'module.'
        # prefix; model.module.state_dict() would strip it — confirm intent.
        torch.save(model.state_dict(), "./last_dict_resnest50-softmax-64batch-distr.pth")