Built-in network models in PyTorch
MODELS AND PRE-TRAINED WEIGHTS
The torchvision.models subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection, video classification, and optical flow.
For example: VGG (an image classification model)
The most commonly used variants are vgg16 and vgg19.
VGG16
torchvision.models.vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) → VGG
VGG-16 from Very Deep Convolutional Networks for Large-Scale Image Recognition.
Parameters:
- weights (VGG16_Weights, optional) – The pretrained weights to use. See VGG16_Weights below for more details, and possible values. By default, no pre-trained weights are used.
- progress (bool, optional) – If True, displays a progress bar of the download to stderr. Default is True.
- **kwargs – parameters passed to the torchvision.models.vgg.VGG base class. Please refer to the source code for more details about this class.
Note
The pretrained parameter is deprecated and should be replaced with the weights parameter:
UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
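With the old form, pretrained=True, there was no real choice of pretrained weights: the call always loaded the one default weight file. That is a limitation, because deep learning moves quickly and better-performing weights are released regularly.
With the new form, weights=<weight version>, the caller decides which pretrained weight file to load, so more accurate and more recent weights can be used as soon as they are available.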
Example
from torchvision import models

# Load the older V1 weights (76.130% accuracy)
model_v1 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Equivalent form
model_v1 = models.resnet50(weights="IMAGENET1K_V1")

# Load the newer V2 weights (80.858% accuracy)
model_v2 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# Equivalent form
model_v2 = models.resnet50(weights="IMAGENET1K_V2")

# If you do not know which version is the newest, just choose DEFAULT
model_new = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
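Example
import torchvision
from torch import nn

# Note: the ImageNet dataset is no longer publicly accessible, and the download argument is deprecated
# train_data = torchvision.datasets.ImageNet(root='./dataset', split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())

vgg16_true = torchvision.models.vgg16(weights='DEFAULT')  # with pretrained weights
vgg16_false = torchvision.models.vgg16(weights=None)      # randomly initialized
print(vgg16_true)

# The model outputs 1000 classes. How can it be adapted to another dataset?
train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, download=True,
                                          transform=torchvision.transforms.ToTensor())

# Option 1: append a linear layer that maps the 1000 outputs to 10 classes.
# Adding it at the top level would not work, because VGG.forward only calls features, avgpool and classifier:
# vgg16_true.add_module('add_linear', nn.Linear(in_features=1000, out_features=10))
vgg16_true.classifier.add_module('add_linear', nn.Linear(in_features=1000, out_features=10))
print(vgg16_true)

# Option 2: replace the last classifier layer (4096 -> 1000) with a 4096 -> 10 layer
vgg16_false.classifier[6] = nn.Linear(in_features=4096, out_features=10)
print(vgg16_false)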
Result with pretrained weights
Result without pretrained weights
Saving and loading models
Saving
import os

import torch
import torchvision
from torch import nn

os.makedirs("./models", exist_ok=True)  # torch.save does not create missing directories

vgg16 = torchvision.models.vgg16(weights=None)

# Method 1: save the model structure together with its parameters
torch.save(vgg16, f=r"./models/vgg16_method1.pth")

# Method 2 (officially recommended, smaller file): save only the parameters, not the structure
torch.save(vgg16.state_dict(), f=r"./models/vgg16_method2.pth")


# Pitfall of method 1
class XiaoMo(nn.Module):
    def __init__(self):
        super(XiaoMo, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)


xiaomo = XiaoMo()
torch.save(xiaomo, f=r"./models/xiaomo.pth")
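To make clear what method 2 actually stores, the sketch below prints the contents of a state_dict for a small model. TinyNet is a hypothetical class used only for illustration; it mirrors the single-convolution model above.

```python
from torch import nn


class TinyNet(nn.Module):  # hypothetical model used only for illustration
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)


state = TinyNet().state_dict()

# A state_dict maps parameter names to tensors; it contains no class or code
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# conv1.weight (32, 3, 3, 3)
# conv1.bias (32,)
```

Because only names and tensors are stored, the file can be loaded into any model object whose parameter names and shapes match.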
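Loading
import torch
import torchvision

# Load a model saved with method 1
# weights_only=False is needed on PyTorch 2.6 and later, where torch.load defaults to
# weights_only=True; use it only for files you trust
model = torch.load("./models/vgg16_method1.pth", weights_only=False)
print(model)

# Load a model saved with method 2
# vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
vgg16 = torchvision.models.vgg16(weights=None)  # first rebuild the model structure
vgg16.load_state_dict(torch.load("./models/vgg16_method2.pth"))  # then load the parameters
# model = torch.load("./models/vgg16_method2.pth")  # this alone returns only the parameter dictionary, not a model
print(vgg16)

# Pitfall of method 1
model = torch.load("./models/xiaomo.pth", weights_only=False)
print(model)
# AttributeError: Can't get attribute 'XiaoMo' on <module '__main__' from 'C:\\Users\\31058\\Desktop\\MyProject\\ML\\test\\20_model_load.py'>
# Cause: the XiaoMo class is not defined in this script; it has to be redefined or imported (importing is the usual approach)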
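One way to sidestep the pitfall is to save only the state_dict and keep the class definition in code. The following is a self-contained sketch of that round trip. TinyNet and the file name are hypothetical, and a temporary directory is used so that nothing is left on disk.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):  # hypothetical model; in a real project, import it from its own module
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)


source = TinyNet()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tinynet_state.pth")  # hypothetical file name
    torch.save(source.state_dict(), path)

    # Loading needs the class definition, which is available in this script
    restored = TinyNet()
    restored.load_state_dict(torch.load(path, map_location="cpu"))

# Every tensor in the restored model matches the original
same = all(
    torch.equal(a, b)
    for a, b in zip(source.state_dict().values(), restored.state_dict().values())
)
print(same)
```

In a multi-file project, the loading script would typically contain a line such as `from model import TinyNet` (module name assumed) instead of repeating the class body.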