12-pytorch中内置网络模型

pytorch中内置网络模型

Docs -> torchvision -> Models and pre-trained weights

MODELS AND PRE-TRAINED WEIGHTS

The torchvision.models subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection, video classification, and optical flow.

如:VGG(一种分类模型)

最常用的有vgg16和vgg19

VGG16

torchvision.models.vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, *kwargs: Any*) → VGG[SOURCE]

VGG-16 from Very Deep Convolutional Networks for Large-Scale Image Recognition.

Parameters:

  • weights (VGG16_Weights, optional) – The pretrained weights to use. See VGG16_Weights below for more details, and possible values. By default, no pre-trained weights are used.

  • progress (bool, optional) – If True, displays a progress bar of the download to stderr. Default is True.

  • **kwargs – parameters passed to the torchvision.models.vgg.VGG base class. Please refer to the source code for more details about this class.

注意

pretrained参数被弃用(deprecated)了,需要更新为weights参数

UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.

在旧版本的写法 pretrained = True 中,对于预训练权重参数我们没有太多选择的余地,一执行起来就要使用默认的预训练权重文件版本。但问题是,现在深度学习的发展日新月异,很快就有性能更强的模型横空出世。

而使用新版本写法 weights=预训练模型参数版本 ,相当于我们掌握了预训练权重参数文件的选择权。我们就可以尽情地使用更准更快更强更新的预训练权重参数文件,帮助我们的研究更上一层楼。

示例
from torchvision import models
​
# 加载精度为76.130%的旧权重参数文件V1
model_v1 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# 等价写法
model_v1 = models.resnet50(weights="IMAGENET1K_V1")
​
# 加载精度为80.858%的新权重参数文件V2
model_v2 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# 等价写法
model_v1 = models.resnet50(weights="IMAGENET1K_V2")
​
# 如果你不知道哪个版本是最新, 直接选择默认DEFAULT即可
model_new = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

示例

import torchvision
from torch import nn
​
# # 注意,该模型现无法公开访问,download参数已弃用
# train_data = torchvision.datasets.ImageNet(root='./dataset', split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())
​
vgg16_true = torchvision.models.vgg16(weights='DEFAULT')
vgg16_false = torchvision.models.vgg16(weights=None)
​
print(vgg16_true)
​
# 该模型输出的分类数是1000,如何修改使之应用到别的数据集上呢?
train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, download=True,
                                          transform=torchvision.transforms.ToTensor())
# vgg16_true.add_module('add_linear', nn.Linear(in_features=1000, out_features=10))
vgg16_true.classifier.add_module('add_linear', nn.Linear(in_features=1000, out_features=10))
print(vgg16_true)
​
vgg16_false.classifier[6] = nn.Linear(in_features=4096, out_features=10)
print(vgg16_false)

预训练后的结果

未预训练的结果

模型的保存和加载

保存

import torch
import torchvision
from torch import nn
​
vgg16 = torchvision.models.vgg16(weights=None)
​
# 保存方式1
# 保存网络模型结构+网络内部参数
torch.save(vgg16, f=r"./models/vgg16_method1.pth")
​
# 保存方式2(官方推荐,占用空间小)
# 只保存网络模型的参数,不保存结构
torch.save(vgg16.state_dict(), f=r"./models/vgg16_method2.pth")
​
# 方式1的陷阱
class XiaoMo(nn.Module):
    def __init__(self):
        super(XiaoMo, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
​
    def forward(self, x):
        return self.conv1(x)
​
​
xiaomo = XiaoMo()
torch.save(xiaomo, f=r"./models/xiaomo.pth")

加载

import torch
import torchvision
​
# 保存方式1->加载模型
model = torch.load("./models/vgg16_method1.pth")
print(model)
​
# 保存方式2->加载模型模型
# vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)  # 先新建网络模型结构
vgg16 = torchvision.models.vgg16(weights=None)  # 先新建网络模型结构
vgg16.load_state_dict(torch.load("./models/vgg16_method2.pth"))  # 加载模型参数
# model = torch.load("./models/vgg16_method2.pth")
print(vgg16)
​
# 方式1保存的陷阱
model = torch.load("./models/xiaomo.pth")
print(model)
# AttributeError: Can't get attribute 'XiaoMo' on <module '__main__' from 'C:\\Users\\31058\\Desktop\\MyProject\\ML\\test\\20_model_load.py'>
# 原因是类未声明,需要重新声明或者导入(一般采用导入的方法)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值