模型分为几类,一种是自己写好的模型,一种一些成熟网络模型来做迁移或者预训练。
1.预训练模型加载
本文以resnet50网络为例。先对网络模型简单介绍。
https://arxiv.org/pdf/1512.03385v1.pdf
摘要翻译:更深层次的神经网络训练更加困难。我们提出一个 Residual的学习框架来缓解训练的网比之前所使用的网络深得多。我们提供全面的经验证据显示这些残余网络更容易优化,并可以从显着增加的深度获得准确性。在ImageNet数据集上我们评估深度达152层残留网比VGG网[41]更深,但复杂度仍然较低。这些残留网络的集合实现了3.57%的误差在ImageNet测试集上。这个结果赢得了ILSVRC 2015分类任务第一名。
import torch
import torchvision
# prepare model
mode1_resnet50 = torchvision.models.resnet50(pretrained=True)
这种会同时加载模型和参数
2)只加载模型,不加载预训练参数
# 导入模型结构
resnet18 = models.resnet18(pretrained=False)
# 加载预先下载好的预训练参数到resnet18
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))
加载部分预训练模型
resnet152 = models.resnet152(pretrained=True)
pretrained_dict = resnet152.state_dict()
"""加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
也可以直接从官方model_zoo下载:
pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
model_dict = model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)
2.简易加载已有模型
# 保存和加载整个模型
torch.save(model_object, 'net.pth')
model = torch.load('net.pth')
这也是pytorch官网推荐的一种加载方式,容易上手。但是无法做到模型和超参数分开。
3.分别加载网络的结构和参数
# 将my_resnet模型储存为net.pth
torch.save(my_resnet.state_dict(), "net.pth")
# 加载net,模型存放在net.pth
my_resnet.load_state_dict(torch.load("net.pth"))