# 如
model = model.restnet50()
# 如果有保存好的训练文件,在后面加上下面几句话
resume = 'checkpoint-480.pth'
checkpoint = torch.load(resume)
model.load_state_dict (checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
保存模型的方法:
# 定义要保留的格式及数据
def save_checkpoint(path, model, optimizer):
state = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict()
}
torch.save(state, path)
# 在一个epoch结束后,写上:
save_checkpoint('checkpoint-%i.pth' % index, model, optimizer)
# optimizer举例
optimizer = torch.optim.SGD(
model.parameters(),
lr=LEARNING_RATE,
momentum=MOMENTUM
)