RuntimeError: Error(s) in loading state_dict for RNN:问题解决

RuntimeError: Error(s) in loading state_dict for RNN:问题解决

当我们在进行机器学习时,会出现加载训练好的模型参数出现错误的情况,以下便是我遇到的情况。
在这里插入图片描述
在这里插入图片描述
从上图可见,在加载训练好的模型参数时出现了问题,报错意思很明显,即加载的参数名称和新实例化的模型参数名称不一样。
首先,查看已训练好的模型参数,如下图所示。
在这里插入图片描述
而新实例化的模型参数如下图所示。
在这里插入图片描述
很明显,参数列表的名称不一样。
我怀疑我没有保存进去,所以重新训练一遍模型,结果如下所示:

在这里插入图片描述
果然。。。。。。

接下来介绍几种pytorch保存并加载模型的方法。

  • 方法一:加载/保存整个模型

保存:torch.save(model, PATH)
加载:model = torch.load(PATH)
保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现。这种实现保存模型的做法将是采用 Python 的 pickle 模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle 并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。

  • 方法二:加载和保存一个通用的检查点(Checkpoint)
    保存的示例代码:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

加载的示例代码:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict ,比如说优化器的 state_dict 也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding 层等等。

上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save 方法,一般保存的文件后缀名是 .tar 。

加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load 加载对应的 state_dict 。然后通过不同的键来获取对应的数值。

加载完后,根据后续步骤,调用 model.eval() 用于预测,model.train() 用于恢复训练。

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值