之前写的这篇文章内容不是很全面,今天组会师兄给予了指正并认真讲解,进而进行了相关的更新,见解可能不是很全面,如有问题恳请指正
关于这次更新主要有以下几方面的内容改进(更新于20200426)
- 对于多步长训练需要保存lr_schedule
- 初始化随机数种子
- 保存每一代最好的结果
最近在尝试用CIFAR10训练分类问题的时候,由于数据集体量比较大,训练的过程中时间比较长,有时候想给停下来,但是停下来了之后就得重新训练,之前师兄让我们学习断点继续训练及继续训练的时候注意epoch的改变等,今天上午给大致整理了一下,不全面仅供参考
Epoch: 9 | train loss: 0.3517 | test accuracy: 0.7184 | train time: 14215.1018 s
Epoch: 9 | train loss: 0.2471 | test accuracy: 0.7252 | train time: 14309.1216 s
Epoch: 9 | train loss: 0.4335 | test accuracy: 0.7201 | train time: 14403.2398 s
Epoch: 9 | train loss: 0.2186 | test accuracy: 0.7242 | train time: 14497.1921 s
Epoch: 9 | train loss: 0.2127 | test accuracy: 0.7196 | train time: 14591.4974 s
Epoch: 9 | train loss: 0.1624 | test accuracy: 0.7142 | train time: 14685.7034 s
Epoch: 9 | train loss: 0.1795 | test accuracy: 0.7170 | train time: 14780.2831 s
绝望!!!!!训练到了一定次数发现训练次数少了,或者中途断了又得重新开始训练
一、模型的保存与加载
PyTorch中的保存(序列化,从内存到硬盘)与反序列化(加载,从硬盘到内存)
torch.save主要参数: obj:对象 、f:输出路径
torch.load 主要参数 :f:文件路径 、map_location:指定存放位置、 cpu or gpu
模型的保存的两种方法:
1、保存整个Module
torch.save(net, path)
2、保存模型参数
state_dict = net.state_dict()
torch.save(state_dict , path)
二、模型的训练过程中保存
checkpoint = {
"net": model.state_dict(),
'optimizer':optimizer.state_dict(),
"epoch": epoch
}
将网络训练过程中的网络的权重,优化器的权重保存,以及epoch 保存,便于继续训练恢复
在训练过程中,可以根据自己的需要,每多少代,或者多少epoch保存一次网络参数,便于恢复,提高程序的鲁棒性。
checkpoint = {
"net": model.state_dict(),
'optimizer':optimizer.state_dict(),
"epoch": epoch
}
if not os.pat