1. 方式一: 加载和保存状态字典
只保存参数的方法在加载的时候要事先定义好跟原模型一致的模型,并在该模型的实例对象(假设名为model)上进行加载,即在使用上述加载语句前已经有定义了一个和原模型一样的Net, 并且进行了实例化 model=Net( ) .
## 保存模型的字典参数, 保存模型的后缀通常为'.pt' 或者 ‘.pth’
torch.save(model.state_dict(), PATH) ## PATH必须要有文件名称如./model.pth
## 加载模型
model = init() ## 意思为每一个不同类别的模型,需要先初始化,这里并不会加载模型,只是申请了一个网络框架;
model.load_state_dict(torch.load(PATH))
model.eval()
2. 方式二: 加载和保存整个模型
## 保存整个模型
torch.save(model, PATH)
## 加载模型
model = torch.load(PATH)
model.eval()
3. 方式三: 保存训练的某个checkpoint
## a. 先建立一个字典,保存三个参数:
state = {‘model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
## b. 保存checkpoints
torch.save(state, PATH)
## c. 加载checkpoints去继续训练或预测
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
4. pytorch训练显示进度条的工具
tqdm
https://github.com/tqdm/tqdm