When something goes wrong during training and we have to stop, resuming from where the last run left off requires checkpoint (breakpoint) training: saving the model weights, the optimizer, the learning-rate scheduler, the current epoch, and other state.
Saving the checkpoint (cfg is my configuration object):
checkpoint = {
    'epoch': epoch,
    # when training on multiple GPUs (DataParallel), save the underlying module's weights
    'model': model.state_dict() if not cfg.more_gpu else model.module.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_schedule': scheduler.state_dict(),
    'best_acc': best_acc,
}
torch.save(checkpoint, cfg.checkpoint_path)
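In practice this save block runs at the end of every epoch, after the scheduler has stepped, so all of the saved states refer to the same point in training. A minimal sketch of that placement; the helpers train_one_epoch and evaluate are hypothetical and stand in for your own loop, while cfg, model, optimizer and scheduler are the objects from above:

for epoch in range(cfg.epochs):
    train_one_epoch(model, optimizer, train_loader)   # hypothetical training helper
    acc = evaluate(model, val_loader)                 # hypothetical validation helper
    scheduler.step()
    best_acc = max(best_acc, acc)
    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict() if not cfg.more_gpu else model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_schedule': scheduler.state_dict(),
        'best_acc': best_acc,
    }
    torch.save(checkpoint, cfg.checkpoint_path)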
Before training starts, check whether this is a resumed run; if so, read the previous checkpoint and restore all of the saved state (or change the learning rate, and so on).
The commented-out lines are where you can change parameters when resuming: if you think the learning rate in the last run was too large, you can lower it here, using your training logs to decide.
start_epoch = -1
best_acc = 0.0
if cfg.resume:
    path_checkpoint = cfg.checkpoint_path
    checkpoint = torch.load(path_checkpoint)
    start_epoch = checkpoint['epoch']
    # load the weights into the wrapped module when using multiple GPUs
    if not cfg.more_gpu:
        model.load_state_dict(checkpoint['model'])
    else:
        model.module.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # optimizer = torch.optim.__dict__[cfg.optimizer_type](model.parameters(), lr=0.0001, weight_decay=cfg.wd)
    scheduler.load_state_dict(checkpoint['lr_schedule'])
    # scheduler = lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)
    best_acc = checkpoint['best_acc']
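With the state restored, the loop then starts from start_epoch + 1 instead of 0, so completed epochs are not repeated. A minimal sketch under the same assumptions as before (train_one_epoch and evaluate are hypothetical helpers):

for epoch in range(start_epoch + 1, cfg.epochs):
    train_one_epoch(model, optimizer, train_loader)   # hypothetical training helper
    acc = evaluate(model, val_loader)                 # hypothetical validation helper
    scheduler.step()            # continues from the restored learning-rate schedule
    best_acc = max(best_acc, acc)
    # re-save the checkpoint here (same dict as in the save snippet above),
    # so a second interruption costs at most one epoch of work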