保存模型时,将 epoch、model、optimizer 的参数一并保存下来:
# Assemble the resume bundle key by key: the current step counter (stored
# under 'epoch'), the refiner's state dict and the optimizer's state dict.
checkpoint = {}
checkpoint['epoch'] = self.step
checkpoint['model'] = self.refiner.state_dict()
checkpoint['optimizer'] = self.optim.state_dict()

# Delegate the actual write to save(), using the configured directory/prefix.
self.save(cfg.ckpt_dir, cfg.ckpt_name, checkpoint)
def save(self, ckpt_dir, ckpt_name, checkpoint):
    """Write ``checkpoint`` to ``<ckpt_dir>/<ckpt_name>_<step>.pth``.

    Args:
        ckpt_dir: Directory that receives the file. It is created
            (including parents) when it does not exist yet.
        ckpt_name: File-name prefix; ``self.step`` is appended to it.
        checkpoint: Object passed unchanged to ``torch.save``.
    """
    # torch.save opens the target path directly and does not create missing
    # parent directories, so make sure the directory exists first. An empty
    # ckpt_dir means "current directory" and needs no creation.
    if ckpt_dir and not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    save_path = os.path.join(
        ckpt_dir, "{}_{}.pth".format(ckpt_name, self.step))
    torch.save(checkpoint, save_path)
    # NOTE: alternatives are torch.save(self.refiner.state_dict(), save_path)
    # to store only the model parameters, or torch.save(self.refiner,
    # save_path) to store the whole model object.
在训练开始时加载已保存的检查点,并用读取到的参数恢复模型、优化器和训练步数:
# Resume only when requested through the config (cfg.resume > 0).
if cfg.resume > 0:
    # cfg.resume_dir is handed straight to torch.load, i.e. it is used as
    # the path of the checkpoint to read.
    checkpoint = torch.load(cfg.resume_dir)

    # Restore the optimizer first, then the refiner weights.
    self.optim.load_state_dict(checkpoint['optimizer'])
    self.refiner.load_state_dict(checkpoint['model'])

    # Continue counting from the step after the one stored under 'epoch'.
    resumed_step = checkpoint['epoch']
    self.step = resumed_step + 1