使用Pytorch进行深度学习,保存和加载模型

59 篇文章 4 订阅
46 篇文章 3 订阅
本文介绍了如何在PyTorch中加载预训练模型,包括完整加载、仅加载结构、部分加载参数等方法。以ResNet50为例,展示了预训练模型在ImageNet上的优秀表现,并提供了加载预训练权重的代码示例。此外,还讨论了如何保存和加载整个模型以及分别加载网络结构和参数的操作。
摘要由CSDN通过智能技术生成


模型分为几类,一种是自己写好的模型,一种一些成熟网络模型来做迁移或者预训练。

1.预训练模型加载

本文以resnet50网络为例。先对网络模型简单介绍。
https://arxiv.org/pdf/1512.03385v1.pdf
摘要翻译:更深层次的神经网络训练更加困难。我们提出一个 Residual的学习框架来缓解训练的网比之前所使用的网络深得多。我们提供全面的经验证据显示这些残余网络更容易优化,并可以从显着增加的深度获得准确性。在ImageNet数据集上我们评估深度达152层残留网比VGG网[41]更深,但复杂度仍然较低。这些残留网络的集合实现了3.57%的误差在ImageNet测试集上。这个结果赢得了ILSVRC 2015分类任务第一名。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

import torch
import torchvision
# prepare model
mode1_resnet50 = torchvision.models.resnet50(pretrained=True)

这种会同时加载模型和参数

2)只加载模型,不加载预训练参数

# 导入模型结构
resnet18 = models.resnet18(pretrained=False)
# 加载预先下载好的预训练参数到resnet18
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))

加载部分预训练模型

resnet152 = models.resnet152(pretrained=True)
pretrained_dict = resnet152.state_dict()
"""加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
   也可以直接从官方model_zoo下载:
   pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
model_dict = model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

2.简易加载已有模型

# 保存和加载整个模型
torch.save(model_object, 'net.pth')
model = torch.load('net.pth')

这也是pytorch官网推荐的一种加载方式,容易上手。但是无法做到模型和超参数分开。

3.分别加载网络的结构和参数

# 将my_resnet模型储存为net.pth
torch.save(my_resnet.state_dict(), "net.pth")
# 加载net,模型存放在net.pth
my_resnet.load_state_dict(torch.load("net.pth"))
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值