1. 问题描述:
我在使用torch.save()保存了optimizer的参数过后,
torch.save(
{
'state_dict':net.state_dict(),
'optimizer':optimizer.state_dict(),
'epochID':epoch,
},
filename
)
再次利用optimizer.load_state_dict()加载参数,在optimizer.step()处报错:
RuntimeError: expected device cpu but got device cuda:0
2. 解决办法:
重载optimizer的参数时将所有的tensor都放到cuda上(加载时默认放在cpu上了),代码片段如下:
checkpoint = torch.load(filename)
net.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.cuda()
current_epoch = checkpoint['epochID'] + 1
顺利解决问题~