The official PyTorch documentation (torchvision.models) provides many ready-made network architectures for us to use, for example:
For classification: AlexNet, VGG, ResNet, etc.
For semantic segmentation: FCN (for example with a ResNet50 backbone), etc.
For object detection: SSD, RetinaNet, etc.
For instance segmentation: Mask R-CNN, etc.
The functions that build these networks have two commonly used parameters.
The first is pretrained. If it is False, the network is created with randomly initialized weights. If it is True, weights pre-trained on a dataset (ImageNet for the classification models) are downloaded and loaded.
The second is progress. If it is True, a progress bar is displayed while the pre-trained weights are downloaded. If it is False, no progress bar is displayed.
Now, taking VGG16 as an example, let's observe the effect of the pretrained parameter.
(VGG16 is pre-trained on ImageNet, and downloading ImageNet_2012_train takes more than 100 GB, so instead of downloading the dataset we look at the model's parameters directly.)
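Code:
import torchvision
# pretrained=False: the network is created with randomly initialized weights
vgg16_false = torchvision.models.vgg16(pretrained=False, progress=True)
print(vgg16_false)
# pretrained=True: the ImageNet weights are downloaded (with a progress bar) and loaded
vgg16_true = torchvision.models.vgg16(pretrained=True, progress=True)
print(vgg16_true)
# a convenient line to put a breakpoint on
print("ok")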
Set a breakpoint on the last line and inspect both models in the debugger.
Result:
Taking vgg16 as an example again, suppose we want to use the CIFAR10 dataset, which has only 10 classes.
Since vgg16 was pre-trained on ImageNet, which has 1000 classes, its last layer outputs 1000 values, so the original vgg16 network has to be modified.
There are many ways to modify it.
Method 1: append a fully connected layer at the end of the classifier, with 1000 inputs and 10 outputs.
Method 2: replace the last fully connected layer so that it outputs 10 classes instead of 1000 (4096 inputs, 10 outputs).
Code:
import torch
import torchvision
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
tran_tensor = transforms.ToTensor()
dataset = torchvision.datasets.CIFAR10(root="../dataset", train=True, transform=tran_tensor, download=True)
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0, drop_last=False)
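# method 1: append a Linear(1000, 10) layer to the classifier.
# The layer is added to vgg16_false.classifier (an nn.Sequential) so that it runs in forward();
# calling add_module on vgg16_false itself would only register the layer without ever using it.
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_false.classifier.add_module("add_linear", nn.Linear(1000, 10))
print(vgg16_false)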
print("***********************************************")
# method 2: replace the last classifier layer, Linear(4096, 1000), with Linear(4096, 10)
vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16_true.classifier[6] = nn.Linear(4096, 10)
print(vgg16_true)
Result: