优点:
1.长时间的训练,如果发生中断,继续训练时直接读取
2.通过迁移学习,利用别人训练好的数据进行训练,提高训练效果
三个方面说明
1.模型保存与加载
2.冻结一部分参数,训练另一部分参数
3.采用不同的学习率进行训练
模型保存与加载的三种方式
# 方式一:保存与加载整个state_dict(推荐)
torch.save(model.state_dict(), PATH)
model.load_state_dict(torch.load(PATH)) # 继承自torch.nn.Module.load_state_dict
# 测试时不启用BatchNormalization 和DropOut
model.eval()
# 方式二:保存加载整个模型
torch.save(model, PATH)
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
torch.save({
'epoch': epoch, 'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), 'loss': loss,
...}, PATH)
checkpoint = torch.load(PATH)
start_epoch = checkpoint['epoch']
model.load_state_dict(chechpoint['model_state_dict'])
# 训练时
model.train()
# 测试时
model.eval