简要模板
- save
# Bundle the model weights and the optimizer state into one checkpoint file.
torch.save(
    {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
    "my_checkpoint.pth.tar",
)
- load
# Read the checkpoint once; map_location remaps the stored tensors onto `device`.
checkpoint = torch.load("my_checkpoint.pth.tar", map_location=device)
# The file holds a dict with both states, so pass each object only its own entry.
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
一般模型
保存模型
def save_checkpoint(state, filename='my_checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to serialize. The caller shown in this file passes a
            dict with ``"state_dict"`` and ``"optimizer"`` entries.
        filename: Destination path of the checkpoint file.
    """
    print("=> saving checkpoint")
    torch.save(state, filename)
# Write a checkpoint whenever the epoch index is a multiple of 3.
if epoch % 3 == 0:
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimzer.state_dict(),
    }
    save_checkpoint(checkpoint)
加载模型
def load_checkpoint(checkpoint):
    """Restore ``model`` and ``optimzer`` (taken from the enclosing scope).

    Args:
        checkpoint: Mapping with a ``"state_dict"`` entry for the model and
            an ``"optimizer"`` entry for the optimizer.

    Raises:
        KeyError: If either entry is missing from ``checkpoint``.
    """
    print("=>loading checkpoint")
    # Deep-copy each entry so load_state_dict receives objects that are
    # independent of the caller's checkpoint dict.
    model.load_state_dict(copy.deepcopy(checkpoint["state_dict"]))
    optimzer.load_state_dict(copy.deepcopy(checkpoint['optimizer']))
# Optionally resume from the saved checkpoint; map_location remaps the stored
# tensors onto `device`.
if load_model:
    load_checkpoint(torch.load('my_checkpoint.pth.tar', map_location=device))
注意两点
- 1
checkpoint = {"state_dict": model.state_dict(), 'optimizer': optimzer.state_dict()} 中,state_dict 后面的括号 () 不要漏掉
如果忘记写括号,会报如下错误
state_dict = state_dict.copy()
AttributeError: 'function' object has no attribute 'copy'
- 2
torch.load('my_checkpoint.pth.tar', device) 注意传入 device(即 map_location 参数),其中
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")