pytorch保存模型_【他山之石】一文读懂 PyTorch 模型保存与载入

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者: 知乎—涤生

地址:https://www.zhihu.com/people/wen-xiao-du-4

当我们设计好一个模型,投喂大量数据,并且经过艰苦训练,终于得到了一个表现不错的模型,得到这个模型以后,就要保存下来,为后面的部署或者测试实验做准备。那么该如何正确高效的保存好训练的模型并再次重新载入呢? 首先来定义一个简单的模型
class SampleNet(nn.Module):    def __init__(self):        super(SampleNet, self).__init__()        self.conv1 = nn.Conv2d(3, 16, 3)        self.bn1 = nn.BatchNorm2d(16)        self.relu = nn.ReLU()        self.pool = nn.MaxPool2d(2, 2)        self.conv2 = nn.Conv2d(16, 32, 3)        self.bn2 = nn.BatchNorm2d(32)        self.gap = nn.AdaptiveAvgPool2d(1)        self.fc = nn.Linear(32, 10)    def forward(self, x):        x = self.conv1(x)        x = self.bn1(x)        x = self.relu(x)        x = self.pool(x)        x = self.conv2(x)        x = self.bn2(x)        x = self.relu(x)        x = self.gap(x)        x = x.view(-1, 32)        x = self.fc(x)        return xclass SampleNet2(nn.Module):    def __init__(self):        super(SampleNet2, self).__init__()        self.conv1 = nn.Conv2d(3, 16, 3)        self.bn1 = nn.BatchNorm2d(16)        self.relu = nn.ReLU()        self.pool = nn.MaxPool2d(2, 2)        self.conv2 = nn.Conv2d(16, 32, 3)        self.bn2 = nn.BatchNorm2d(32)        self.conv3 = nn.Conv2d(32, 32, 3)        self.bn3 = nn.BatchNorm2d(32)        self.gap = nn.AdaptiveAvgPool2d(1)        self.fc = nn.Linear(32, 10)    def forward(self, x):        x = self.conv1(x)        x = self.bn1(x)        x = self.relu(x)        x = self.pool(x)        x = self.conv2(x)        x = self.bn2(x)        x = self.relu(x)        x = self.conv3(x)        x = self.bn3(x)        x = self.gap(x)        x = x.view(-1, 32)        x = self.fc(x)        return

只保存模型参数

首先,PyTorch 官方推荐的第一种方式是只保存模型的参数。 对于一个卷积网络模型来说,模型的卷积层、BN层是有经过训练得到的参数的,只需要把对应每一层的参数存储起来,就可以再次加载模型。 而模型的参数存储在一个字典中,通过 `model.state_dict()` 即可得到。
# 保存 PyTorch 模型参数torch.save(model.state_dict(), "model.pt")# 重新载入模型参数# 首先定义模型model = SampleNet()#通过 load_state_dict 函数加载参数,torch.load() 函数中重要的一步是反序列化。model.load_state_dict(torch.load("model.pt"))
利用这种方式,可以更加灵活的利用训练好的模型,比如我需要再次训练的时候,改变来模型,但是只加载部分模型参数。

加载部分模型参数

# SampleNet2 比 SampleNet 多一个卷积层model = SampleNet2()# load_params 和 model_params 分别为两个模型的参数字典load_params = torch.load("model.pt")model_params = model.state_dict()# 构建一个新参数字典,为两个模型重复的部分same_parsms = {k: v for k, v in load_params.items() if k in model_params.keys()}# 更新模型参数字典,并载入model_params.update(same_parsms)model.load_state_dict(model_params)

保存加载全部模型

torch.save(model, "full_model.pt")new_model = torch.load("full_model.pt")

跨设备加载模型

在GPU上保存,在CPU上加载

一般模型训练都是在GPU设备,保存后能在GPU设备上加载运行,而若想在CPU设备上加载,只需在load 函数中加一个map_location参数即可。
device = torch.device('cpu')model = SampleNet()model.load_state_dict(torch.load("model.pt", map_location=device))

在CPU上保存,在GPU上加载

反过来,在也只需要改变map_location参数,但要注意将模型也对应到相同设备。
device = torch.device("cuda")model = SampleNet()model.load_state_dict(torch.load("model.pt", map_location="cuda:0"))model.to(device)

保存 torch.nn.DataParallel 模型

如果是要保存在单机多GPU上训练的模型,则需要特别注意一下。
model = SampleNet()model = torch.nn.DataParallel(model)torch.save(model.module.state_dict(), "model")
以上就是 PyTorch 在保存模型时的一些方式和技巧,希望能够帮助到你~

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。

bf0420f196a6c8b479051352807fb838.gif

直播预告

bf0420f196a6c8b479051352807fb838.gif 4907e7b31ea0c74f7f921c4215f5fb39.png

他山之石”历史文章

  • 适合PyTorch小白的官网教程:Learning PyTorch With Examples

  • pytorch量化备忘录

  • LSTM模型结构的可视化

  • PointNet论文复现及代码详解

  • SCI写作常用句型之研究结果&发现

  • 白话生成对抗网络GAN及代码实现

  • pytorch的余弦退火学习率

  • Pytorch转ONNX-实战篇(tracing机制)

  • 联邦学习:FedAvg 的 Pytorch 实现

  • PyTorch实现ShuffleNet-v2亲身实践

  • 训练时显存优化技术——OP合并与gradient checkpoint

  • 浅谈数据标准化与Pytorch中NLLLoss和CrossEntropyLoss损失函数的区别

  • 在C++平台上部署PyTorch模型流程+踩坑实录

  • libtorch使用经验

  • 深度学习模型转换与部署那些事(含ONNX格式详细分析)

更多他山之石专栏文章,

请点击文章底部“阅读原文”查看

02767cfbfcf5473e07ad8fca728425fd.png e3fd61129a18e952072fb80b9db111c3.gif

分享、点赞、在看,给个三连击呗!

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值