pytorch模型的保存和加载

在训练深度学习模型的过程中,周期性地对模型做存档(Checkpoint)非常重要。一方面,深度模型的训练一般是一个长期的过程,训练过程中会出现各种问题,比如硬件错误或者断电等;另一方面,训练好的模型后需要对实际数据进行预测(Predict,或称为推理Inference),这时候就需要把模型的权重保存到硬盘中,方便后续直接调用模型进行预测。

1. 模块和张量的序列化及反序列化

pytorch的一系列方法,可以将torch.nn.Module和torch.tensor类等的实例转换成字符串,这些实例可以通过Python序列化(Serialization)和反序列化(Deserialization)。pytorch集成了python自带的pickle包对模块和张量进行序列化。

1.1 pytorch保存和载入模型
import torch

torch.save(obj, f, pickle_module=pickle, pickle_protocol=2)
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

其中,torch.save函数传入的第一个参数是pytoch中的可以被序列化的对象,包括模型和张量;第二个参数是存储文件的路径;第三个参数是默认的,传入的是序列化的库,pytorch带序列化库pickle;第四个参数是pickle协议。
torch.load函数的第一个参数是文件路径;第二个参数是张量存储位置的映射,设置map_location='cpu’可以GPU设备号错误问题;pickle_load_args用来指定传给pickle_module.load的参数。

1.2 pytorch状态字典的保存和载入

在pytorch中,一般模型可以有两种保存方式,第一种是直接保存模型的实例,第二种是保存模型的状态字典(State Dict),一个模型的状态字典包含模型所有参数的名字(nn.Module类在初始化模型的时候会自动给所有的参数都分配一个名字),以及名字对应的张量。通过调用一个模型的state_dict方法,可以获取当前模型的状态字典。

import torch
import torch.nn as nn

class LinearModel(nn.Module):
    def __init__(self, ndim):
        super(LinearModel, self).__init__()
        self.ndim = ndim
        
        self.weight = nn.Parameter(torch.randn(ndim, 1))  # 定义权重
        self.bias = nn.Parameter(torch.randn(1))  # 定义偏置
        
    def forward(self, x):
        # y = Wx + b
        return x.mm(self.weight) + self.bias
    
lm = LinearModel(5)
print(lm.state_dict())  # 获取状态字典
print('=' * 50)
t = lm.state_dict()  # 保存状态字典
lm = LinearModel(5)  # 重新定义线性模型
print(lm.state_dict())  # 新的状态字典,模型参数和原来的不同
print('=' * 50)
lm.load_state_dict(t)  # 载入原来的状态字典
print(lm.state_dict())  # 模型参数已更新

在这里插入图片描述

2. 模块状态字典的保存和载入

在pytorch中,使用某一个版本pytorch保存的模块序列化文件,无法被另一个版本的pytorch载入。相比之下,pytorch张量的实现变动较小,而状态字典含有张量参数的名字和张量参数具体的信息,与模块的关联较小。所以,可以调用state_dict方法来获取状态字典,然后保存该张量字典来保存模型,这样可以最大限度地减少代码对pytorch版本的依赖性。一般一个pytorch训练的检查点的信息包含模型的状态、优化器的状态和当前迭代的步数、损失函数和准确率平均值等。

save_info = {  # 保存的信息
    'iter_num': iter_num,  # 迭代步数
    'optimizer': optimizer.state_dict(),  # 优化器的状态字典
    'model': model.state_dict()  # 模型的状态字典
}
# 保存信息
torch.save(save_info, save_path)
# 载入信息
save_info = torch.load(save_path)
optimizer.load_state_dict(save_info['optimizer'])
model.load_state_dict(save_info['model'])
  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

饕餮&化骨龙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值