pytorch 中参数的保存(save),加载操作(load)

最近写程序,遇到了保存和加载参数的问题,随通过查阅,留下笔记。

参数的保存

首先,参数的保存用的是 torch.save(),具体操作:

for epoch in range(num_epoch):  #训练数据集的迭代次数,这里cifar10数据集将迭代2次
    train_loss = 0.0
    for batch_idx, data in enumerate(trainloader, 0):
        #初始化
        inputs, labels = data #获取数据
        optimizer.zero_grad() #先将梯度置为0
        
        #优化过程
        outputs = net(inputs) #将数据输入到网络,得到第一轮网络前向传播的预测结果outputs
        loss = criterion(outputs, labels) #预测结果outputs和labels通过之前定义的交叉熵计算损失
        loss.backward() #误差反向传播
        optimizer.step() #随机梯度下降方法(之前定义)优化权重
        
        #查看网络训练状态
        train_loss += loss.item()
        if batch_idx % 2000 == 1999: #每迭代2000个batch打印看一次当前网络收敛情况
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, train_loss / 2000))
            train_loss = 0.0
    
    print('Saving epoch %d model ...' % (epoch + 1))
    #####参数保存###########
    state = {
        'net': net.state_dict(),
        'epoch': epoch + 1,
    }                                 # 1 、 先建立一个字典
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')       # 2 、 建立一个保存参数的文件夹
    torch.save(state, './checkpoint/sence15_epoch_%d.ckpt' % (epoch + 1))# 3 、保存操作
    # 因为在for epoch in range(num_epoch)这个循环中,所以可以 保存每一个epoch的参数,如果不在这个循环中,
    #而是循环完成在保存,则保存的是最后一个epoch的参数

print('Finished Training')

结果如图所示
在这里插入图片描述

参数的加载

checkpoint = torch.load('./checkpoint/sence15_epoch_60.ckpt')#载入现有模型
net.load_state_dict(checkpoint['net'])
start_epoch = checkpoint['epoch']

参考链接: https://blog.csdn.net/weixin_38145317/article/details/103582549.
这个链接写的很简单凝练,可以参考
在这里插入图片描述

  • 6
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值