最近写程序,遇到了保存和加载参数的问题,随通过查阅,留下笔记。
参数的保存
首先,参数的保存用的是 torch.save(),具体操作:
for epoch in range(num_epoch): #训练数据集的迭代次数,这里cifar10数据集将迭代2次
train_loss = 0.0
for batch_idx, data in enumerate(trainloader, 0):
#初始化
inputs, labels = data #获取数据
optimizer.zero_grad() #先将梯度置为0
#优化过程
outputs = net(inputs) #将数据输入到网络,得到第一轮网络前向传播的预测结果outputs
loss = criterion(outputs, labels) #预测结果outputs和labels通过之前定义的交叉熵计算损失
loss.backward() #误差反向传播
optimizer.step() #随机梯度下降方法(之前定义)优化权重
#查看网络训练状态
train_loss += loss.item()
if batch_idx % 2000 == 1999: #每迭代2000个batch打印看一次当前网络收敛情况
print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, train_loss / 2000))
train_loss = 0.0
print('Saving epoch %d model ...' % (epoch + 1))
#####参数保存###########
state = {
'net': net.state_dict(),
'epoch': epoch + 1,
} # 1 、 先建立一个字典
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint') # 2 、 建立一个保存参数的文件夹
torch.save(state, './checkpoint/sence15_epoch_%d.ckpt' % (epoch + 1))# 3 、保存操作
# 因为在for epoch in range(num_epoch)这个循环中,所以可以 保存每一个epoch的参数,如果不在这个循环中,
#而是循环完成在保存,则保存的是最后一个epoch的参数
print('Finished Training')
结果如图所示
参数的加载
checkpoint = torch.load('./checkpoint/sence15_epoch_60.ckpt')#载入现有模型
net.load_state_dict(checkpoint['net'])
start_epoch = checkpoint['epoch']
参考链接: https://blog.csdn.net/weixin_38145317/article/details/103582549.
这个链接写的很简单凝练,可以参考