pytorch保存模型、参数的方法
方法一:只保存模型的参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu' )
filepath = 'model.dat'
torch.save(model.state_dict(), filepath)
model.load_state_dict(torch.load(filepath, map_location=device)
方法二:保存模型参数的同时保存其他训练相关状态,以便再次加载先前状态进行训练
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
state = torch.load(filepath)
epoch = state['epoch']
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
...
方法三: 保存整个模型(一般不建议使用)
torch.save(model, filepath)
model = torch.load(filepath)
参考链接: https://stackoverflow.com/a/49078976.